summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-05-03 23:27:03 +0300
committerGitHub <noreply@github.com>2025-05-03 23:27:03 +0300
commit6f6103c4dbc77d5bceae7c8e766ec3cabc293364 (patch)
tree5d00a97065771ff3dd95b837e6e005512797487e
parent7f9283a34b4aaf3401cdb652a2f9208b2b4ff4f4 (diff)
Add IREnumType to distinguish enums from ints and each other (#6973)
* Add IREnumType to distinguish enums from ints and each other * Add issue example as test * format code * Add expected test output * Fix peephole optimization hanging No idea why this PR triggered this, but there seems to have been a clear bug here anyway, so may just as well fix it now. * Move enum lowering later * Add linkage decoration to enum type * Use filecheck-buffer instead of expected.txt * Fix comment * Make enum casts actually use IR enum casts They were all BuiltinCasts by accident * Lower enum type before VM * Deal with rate-qualified types in enum cast * Allow any value marshalling for enum types * Handle new enum instructions in a couple more switches * Fix formatting --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang/slang-emit.cpp8
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp13
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h3
-rw-r--r--source/slang/slang-ir-layout.cpp7
-rw-r--r--source/slang/slang-ir-lower-enum-type.cpp149
-rw-r--r--source/slang/slang-ir-lower-enum-type.h14
-rw-r--r--source/slang/slang-ir-util.cpp8
-rw-r--r--source/slang/slang-ir.cpp42
-rw-r--r--source/slang/slang-ir.h8
-rw-r--r--source/slang/slang-lower-to-ir.cpp9
-rw-r--r--tests/bugs/gh-6964.slang58
13 files changed, 317 insertions, 9 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index aa7387e22..ac6336e9c 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -72,6 +72,7 @@
#include "slang-ir-lower-combined-texture-sampler.h"
#include "slang-ir-lower-coopvec.h"
#include "slang-ir-lower-dynamic-resource-heap.h"
+#include "slang-ir-lower-enum-type.h"
#include "slang-ir-lower-generics.h"
#include "slang-ir-lower-glsl-ssbo-types.h"
#include "slang-ir-lower-l-value-cast.h"
@@ -312,6 +313,7 @@ struct RequiredLoweringPassSet
bool debugInfo;
bool resultType;
bool optionalType;
+ bool enumType;
bool combinedTextureSamplers;
bool reinterpret;
bool generics;
@@ -356,6 +358,9 @@ void calcRequiredLoweringPassSet(
case kIROp_OptionalType:
result.optionalType = true;
break;
+ case kIROp_EnumType:
+ result.enumType = true;
+ break;
case kIROp_TextureType:
if (!isKhronosTarget(codeGenContext->getTargetReq()))
{
@@ -1159,6 +1164,9 @@ Result linkAndOptimizeIR(
cleanupGenerics(targetProgram, irModule, sink);
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS");
+ if (requiredLoweringPassSet.enumType)
+ lowerEnumType(irModule, sink);
+
// Don't need to run any further target-dependent passes if we are generating code
// for host vm.
if (target == CodeGenTarget::HostVM)
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index d124b293d..b3bcf3316 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -155,6 +155,13 @@ struct AnyValueMarshallingContext
case kIROp_PtrType:
context->marshalBasicType(builder, dataType, concreteTypedVar);
break;
+ case kIROp_EnumType:
+ {
+ auto enumType = static_cast<IREnumType*>(dataType);
+ auto tagType = enumType->getTagType();
+ context->marshalBasicType(builder, tagType, concreteTypedVar);
+ break;
+ }
case kIROp_VectorType:
{
auto vectorType = static_cast<IRVectorType*>(dataType);
@@ -868,6 +875,12 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
case kIROp_UInt8Type:
case kIROp_Int8Type:
return offset + 1;
+ case kIROp_EnumType:
+ {
+ auto enumType = static_cast<IREnumType*>(type);
+ auto tagType = enumType->getTagType();
+ return _getAnyValueSizeRaw(tagType, offset);
+ }
case kIROp_VectorType:
{
auto vectorType = static_cast<IRVectorType*>(type);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 08ca56f7d..febd30d7e 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -2157,6 +2157,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_GetNativePtr:
case kIROp_CastIntToFloat:
case kIROp_CastFloatToInt:
+ case kIROp_CastIntToEnum:
+ case kIROp_CastEnumToInt:
+ case kIROp_EnumCast:
case kIROp_DetachDerivative:
case kIROp_GetSequentialID:
case kIROp_GetStringHash:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 8fb1bc9ac..66505dbd0 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0)
INST(AttributedType, Attributed, 0, HOISTABLE)
INST(ResultType, Result, 2, HOISTABLE)
INST(OptionalType, Optional, 1, HOISTABLE)
+ INST(EnumType, Enum, 1, PARENT)
INST(DifferentialPairType, DiffPair, 1, HOISTABLE)
INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE)
@@ -1232,6 +1233,9 @@ INST(CastPtrToInt, CastPtrToInt, 1, 0)
INST(CastIntToPtr, CastIntToPtr, 1, 0)
INST(CastToVoid, castToVoid, 1, 0)
INST(PtrCast, PtrCast, 1, 0)
+INST(CastEnumToInt, CastEnumToInt, 1, 0)
+INST(CastIntToEnum, CastIntToEnum, 1, 0)
+INST(EnumCast, EnumCast, 1, 0)
INST(CastUInt2ToDescriptorHandle, CastUInt2ToDescriptorHandle, 1, 0)
INST(CastDescriptorHandleToUInt2, CastDescriptorHandleToUInt2, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 21d5e1c23..5b196cf24 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4299,6 +4299,9 @@ public:
// Create an initially empty `class` type.
IRClassType* createClassType();
+ // Create an an `enum` type with the given tag type.
+ IREnumType* createEnumType(IRType* tagType);
+
// Create an initially empty `GLSLShaderStorageBufferType` type.
IRGLSLShaderStorageBufferType* createGLSLShaderStorableBufferType();
IRGLSLShaderStorageBufferType* createGLSLShaderStorableBufferType(
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index df643a3c1..7b3aa4340 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -377,6 +377,13 @@ static Result _calcSizeAndAlignment(
attributedType->getBaseType(),
outSizeAndAlignment);
}
+ case kIROp_EnumType:
+ {
+ auto enumType = cast<IREnumType>(type);
+ auto tagType = enumType->getTagType();
+ return _calcSizeAndAlignment(optionSet, rules, tagType, outSizeAndAlignment);
+ }
+ break;
default:
break;
}
diff --git a/source/slang/slang-ir-lower-enum-type.cpp b/source/slang/slang-ir-lower-enum-type.cpp
new file mode 100644
index 000000000..548c29a51
--- /dev/null
+++ b/source/slang/slang-ir-lower-enum-type.cpp
@@ -0,0 +1,149 @@
+// slang-ir-lower-enum-type.cpp
+
+#include "slang-ir-lower-enum-type.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct EnumTypeLoweringContext
+{
+ IRModule* module;
+ DiagnosticSink* sink;
+
+ InstWorkList workList;
+ InstHashSet workListSet;
+
+ IRGeneric* genericOptionalStructType = nullptr;
+ IRStructKey* valueKey = nullptr;
+ IRStructKey* hasValueKey = nullptr;
+
+ EnumTypeLoweringContext(IRModule* inModule)
+ : module(inModule), workList(inModule), workListSet(inModule)
+ {
+ }
+
+ struct LoweredEnumTypeInfo : public RefObject
+ {
+ IRType* enumType = nullptr;
+ IRType* loweredType = nullptr;
+ };
+ Dictionary<IRInst*, RefPtr<LoweredEnumTypeInfo>> loweredEnumTypes;
+
+ void addToWorkList(IRInst* inst)
+ {
+ if (workListSet.contains(inst))
+ return;
+
+ workList.add(inst);
+ workListSet.add(inst);
+ }
+
+ LoweredEnumTypeInfo* getLoweredEnumType(IRInst* type)
+ {
+ if (auto loweredInfo = loweredEnumTypes.tryGetValue(type))
+ return loweredInfo->Ptr();
+
+ if (!type)
+ return nullptr;
+
+ if (type->getOp() != kIROp_EnumType)
+ return nullptr;
+
+ RefPtr<LoweredEnumTypeInfo> info = new LoweredEnumTypeInfo();
+ auto enumType = cast<IREnumType>(type);
+ auto valueType = enumType->getTagType();
+ info->enumType = (IRType*)type;
+ info->loweredType = valueType;
+ loweredEnumTypes[type] = info;
+ return info.Ptr();
+ }
+
+ void processEnumType(IREnumType* inst)
+ {
+ auto loweredEnumTypeInfo = getLoweredEnumType(inst);
+ SLANG_ASSERT(loweredEnumTypeInfo);
+ SLANG_UNUSED(loweredEnumTypeInfo);
+ }
+
+ void processEnumCast(IRInst* inst)
+ {
+ IRBuilder builderStorage(module);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto value = inst->getOperand(0);
+ if (auto enumType = getLoweredEnumType(value->getDataType()))
+ {
+ auto rate = value->getRate();
+ auto type = enumType->loweredType;
+ if (rate)
+ {
+ type = builder->getRateQualifiedType(rate, type);
+ }
+
+ value->setFullType(type);
+ }
+
+ auto type = inst->getDataType();
+ if (auto enumType = getLoweredEnumType(type))
+ { // Cast was into enum, so use tag type instead.
+ type = enumType->loweredType;
+ }
+
+ auto cast = builder->emitCast(type, value);
+
+ inst->replaceUsesWith(cast);
+ inst->removeAndDeallocate();
+ }
+
+ void processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_EnumType:
+ processEnumType((IREnumType*)inst);
+ break;
+ case kIROp_CastEnumToInt:
+ case kIROp_CastIntToEnum:
+ case kIROp_EnumCast:
+ processEnumCast(inst);
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processModule()
+ {
+ addToWorkList(module->getModuleInst());
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+ workList.removeLast();
+ workListSet.remove(inst);
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+
+ // Replace all enum types with their lowered equivalent types.
+ for (const auto& [key, value] : loweredEnumTypes)
+ key->replaceUsesWith(value->loweredType);
+ }
+};
+
+void lowerEnumType(IRModule* module, DiagnosticSink* sink)
+{
+ EnumTypeLoweringContext context(module);
+ context.sink = sink;
+ context.processModule();
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-lower-enum-type.h b/source/slang/slang-ir-lower-enum-type.h
new file mode 100644
index 000000000..37a465bcb
--- /dev/null
+++ b/source/slang/slang-ir-lower-enum-type.h
@@ -0,0 +1,14 @@
+// slang-ir-lower-enum-type.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct IRModule;
+class DiagnosticSink;
+
+/// Lower `IREnumType` to their underlying integer types.
+void lowerEnumType(IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 36aec22b5..5aae53747 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -106,6 +106,7 @@ IROp getTypeStyle(IROp op)
{
case kIROp_VoidType:
case kIROp_BoolType:
+ case kIROp_EnumType:
{
return op;
}
@@ -220,6 +221,7 @@ bool isValueType(IRInst* dataType)
case kIROp_FuncType:
case kIROp_RaytracingAccelerationStructureType:
case kIROp_GLSLAtomicUintType:
+ case kIROp_EnumType:
return true;
default:
// Read-only resource handles are considered as Value type.
@@ -271,6 +273,12 @@ bool isSimpleDataType(IRType* type)
case kIROp_AnyValueType:
case kIROp_PtrType:
return true;
+ case kIROp_EnumType:
+ {
+ auto enumType = as<IREnumType>(type);
+ auto tagType = enumType->getTagType();
+ return isSimpleDataType(tagType);
+ }
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
return isSimpleDataType((IRType*)type->getOperand(0));
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index c105a698a..c1cec36fe 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4051,6 +4051,7 @@ enum class TypeCastStyle
Float,
Bool,
Ptr,
+ Enum,
Void
};
static TypeCastStyle _getTypeStyleId(IRType* type)
@@ -4083,6 +4084,8 @@ static TypeCastStyle _getTypeStyleId(IRType* type)
case kIROp_RefType:
case kIROp_ConstRefType:
return TypeCastStyle::Ptr;
+ case kIROp_EnumType:
+ return TypeCastStyle::Enum;
case kIROp_VoidType:
return TypeCastStyle::Void;
default:
@@ -4131,27 +4134,42 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinC
}
};
- static const OpSeq opMap[4][5] = {
- /* To: Int, Float, Bool, Ptr, Void*/
+ static const OpSeq opMap[5][6] = {
+ /* To: Int, Float, Bool, Ptr, Enum, Void */
/* From Int */ {
kIROp_IntCast,
kIROp_CastIntToFloat,
kIROp_IntCast,
kIROp_CastIntToPtr,
+ kIROp_CastIntToEnum,
kIROp_CastToVoid},
/* From Float */
{kIROp_CastFloatToInt,
kIROp_FloatCast,
{kIROp_Neq},
{kIROp_CastFloatToInt, kIROp_CastIntToPtr},
+ {kIROp_CastFloatToInt, kIROp_CastIntToEnum},
kIROp_CastToVoid},
/* From Bool */
- {kIROp_IntCast, kIROp_CastIntToFloat, kIROp_Nop, kIROp_CastIntToPtr, kIROp_CastToVoid},
+ {kIROp_IntCast,
+ kIROp_CastIntToFloat,
+ kIROp_Nop,
+ kIROp_CastIntToPtr,
+ kIROp_CastIntToEnum,
+ kIROp_CastToVoid},
/* From Ptr */
{kIROp_CastPtrToInt,
{kIROp_CastPtrToInt, kIROp_CastIntToFloat},
kIROp_CastPtrToBool,
kIROp_BitCast,
+ {kIROp_CastPtrToInt, kIROp_CastIntToEnum},
+ kIROp_CastToVoid},
+ /* From Enum */
+ {kIROp_CastEnumToInt,
+ {kIROp_CastEnumToInt, kIROp_CastIntToFloat},
+ {kIROp_CastEnumToInt, kIROp_IntCast},
+ {kIROp_CastEnumToInt, kIROp_CastIntToPtr},
+ kIROp_EnumCast,
kIROp_CastToVoid},
};
@@ -4252,6 +4270,11 @@ IRInst* IRBuilder::emitVectorReshape(IRType* type, IRInst* value)
}
return emitMakeVector(targetVectorType, args);
}
+ else
+ {
+ // Sizes match, no need to reshape.
+ return value;
+ }
}
auto reshape = emitIntrinsicInst(
getVectorType(sourceVectorType->getElementType(), targetVectorType->getElementCount()),
@@ -4807,6 +4830,13 @@ IRClassType* IRBuilder::createClassType()
return classType;
}
+IREnumType* IRBuilder::createEnumType(IRType* tagType)
+{
+ IREnumType* enumType = createInst<IREnumType>(this, kIROp_EnumType, getTypeKind(), tagType);
+ addGlobalValue(this, enumType);
+ return enumType;
+}
+
IRGLSLShaderStorageBufferType* IRBuilder::createGLSLShaderStorableBufferType()
{
IRGLSLShaderStorageBufferType* ssboType = createInst<IRGLSLShaderStorageBufferType>(
@@ -8505,6 +8535,9 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_CastPtrToInt:
case kIROp_CastIntToPtr:
case kIROp_PtrCast:
+ case kIROp_CastEnumToInt:
+ case kIROp_CastIntToEnum:
+ case kIROp_EnumCast:
case kIROp_CastUInt2ToDescriptorHandle:
case kIROp_CastDescriptorHandleToUInt2:
case kIROp_CastDescriptorHandleToResource:
@@ -8956,6 +8989,9 @@ bool isMovableInst(IRInst* inst)
case kIROp_CastPtrToBool:
case kIROp_CastPtrToInt:
case kIROp_PtrCast:
+ case kIROp_CastEnumToInt:
+ case kIROp_CastIntToEnum:
+ case kIROp_EnumCast:
case kIROp_CastDynamicResource:
case kIROp_BitAnd:
case kIROp_BitNot:
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index dbf2b91be..ea784c9a6 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -2059,6 +2059,14 @@ struct IROptionalType : IRType
IRType* getValueType() { return (IRType*)getOperand(0); }
};
+/// Represents an enum type
+struct IREnumType : IRType
+{
+ IR_LEAF_ISA(EnumType)
+
+ IRType* getTagType() { return (IRType*)getOperand(0); }
+};
+
struct IRTypeType : IRType
{
IR_LEAF_ISA(TypeType);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 990090537..45efca2d9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9228,18 +9228,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto subContext = nestedContext.getContext();
auto outerGeneric = emitOuterGenerics(subContext, decl, decl);
- // An `enum` declaration will currently lower directly to its "tag"
- // type, so that any references to the `enum` become referenes to
- // the tag type instead.
- //
// TODO: if we ever support `enum` types with payloads, we would
// need to make the `enum` lower to some kind of custom "tagged union"
// type.
IRType* loweredTagType = lowerType(subContext, decl->tagType);
+ IRType* enumType = subBuilder->createEnumType(loweredTagType);
+ addLinkageDecoration(context, enumType, decl);
- return LoweredValInfo::simple(
- finishOuterGenerics(subBuilder, loweredTagType, outerGeneric));
+ return LoweredValInfo::simple(finishOuterGenerics(subBuilder, enumType, outerGeneric));
}
LoweredValInfo visitThisTypeDecl(ThisTypeDecl* decl)
diff --git a/tests/bugs/gh-6964.slang b/tests/bugs/gh-6964.slang
new file mode 100644
index 000000000..b283a17a8
--- /dev/null
+++ b/tests/bugs/gh-6964.slang
@@ -0,0 +1,58 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -dx12
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu
+
+// CHECK: 1
+// CHECK-NEXT: 0
+// CHECK-NEXT: 2
+// CHECK-NEXT: 3
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+interface IThing
+{
+ void writeInfo(int i);
+}
+
+enum FirstEnum
+{
+ A,
+ B
+};
+
+enum SecondEnum
+{
+ C,
+ D=3
+};
+
+extension FirstEnum: IThing
+{
+ void writeInfo(int i)
+ {
+ outputBuffer[i] = 1;
+ outputBuffer[i+1] = this;
+ }
+}
+
+extension SecondEnum: IThing
+{
+ void writeInfo(int i)
+ {
+ outputBuffer[i] = 2;
+ outputBuffer[i+1] = this;
+ }
+}
+
+void indirectionFunc<T: IThing>(T val, int i)
+{
+ val.writeInfo(i);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ indirectionFunc(FirstEnum.A, 0);
+ indirectionFunc(SecondEnum.D, 2);
+}