diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-05-03 23:27:03 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-03 23:27:03 +0300 |
| commit | 6f6103c4dbc77d5bceae7c8e766ec3cabc293364 (patch) | |
| tree | 5d00a97065771ff3dd95b837e6e005512797487e /source | |
| parent | 7f9283a34b4aaf3401cdb652a2f9208b2b4ff4f4 (diff) | |
Add IREnumType to distinguish enums from ints and each other (#6973)
* Add IREnumType to distinguish enums from ints and each other
* Add issue example as test
* format code
* Add expected test output
* Fix peephole optimization hanging
No idea why this PR triggered this, but there seems to have been a clear bug
here anyway, so may just as well fix it now.
* Move enum lowering later
* Add linkage decoration to enum type
* Use filecheck-buffer instead of expected.txt
* Fix comment
* Make enum casts actually use IR enum casts
They were all BuiltinCasts by accident
* Lower enum type before VM
* Deal with rate-qualified types in enum cast
* Allow any value marshalling for enum types
* Handle new enum instructions in a couple more switches
* Fix formatting
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-any-value-marshalling.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-layout.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-enum-type.cpp | 149 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-enum-type.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 9 |
12 files changed, 259 insertions, 9 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index aa7387e22..ac6336e9c 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -72,6 +72,7 @@ #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-coopvec.h" #include "slang-ir-lower-dynamic-resource-heap.h" +#include "slang-ir-lower-enum-type.h" #include "slang-ir-lower-generics.h" #include "slang-ir-lower-glsl-ssbo-types.h" #include "slang-ir-lower-l-value-cast.h" @@ -312,6 +313,7 @@ struct RequiredLoweringPassSet bool debugInfo; bool resultType; bool optionalType; + bool enumType; bool combinedTextureSamplers; bool reinterpret; bool generics; @@ -356,6 +358,9 @@ void calcRequiredLoweringPassSet( case kIROp_OptionalType: result.optionalType = true; break; + case kIROp_EnumType: + result.enumType = true; + break; case kIROp_TextureType: if (!isKhronosTarget(codeGenContext->getTargetReq())) { @@ -1159,6 +1164,9 @@ Result linkAndOptimizeIR( cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); + if (requiredLoweringPassSet.enumType) + lowerEnumType(irModule, sink); + // Don't need to run any further target-dependent passes if we are generating code // for host vm. if (target == CodeGenTarget::HostVM) diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index d124b293d..b3bcf3316 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -155,6 +155,13 @@ struct AnyValueMarshallingContext case kIROp_PtrType: context->marshalBasicType(builder, dataType, concreteTypedVar); break; + case kIROp_EnumType: + { + auto enumType = static_cast<IREnumType*>(dataType); + auto tagType = enumType->getTagType(); + context->marshalBasicType(builder, tagType, concreteTypedVar); + break; + } case kIROp_VectorType: { auto vectorType = static_cast<IRVectorType*>(dataType); @@ -868,6 +875,12 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) case kIROp_UInt8Type: case kIROp_Int8Type: return offset + 1; + case kIROp_EnumType: + { + auto enumType = static_cast<IREnumType*>(type); + auto tagType = enumType->getTagType(); + return _getAnyValueSizeRaw(tagType, offset); + } case kIROp_VectorType: { auto vectorType = static_cast<IRVectorType*>(type); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 08ca56f7d..febd30d7e 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -2157,6 +2157,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_GetNativePtr: case kIROp_CastIntToFloat: case kIROp_CastFloatToInt: + case kIROp_CastIntToEnum: + case kIROp_CastEnumToInt: + case kIROp_EnumCast: case kIROp_DetachDerivative: case kIROp_GetSequentialID: case kIROp_GetStringHash: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 8fb1bc9ac..66505dbd0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0) INST(AttributedType, Attributed, 0, HOISTABLE) INST(ResultType, Result, 2, HOISTABLE) INST(OptionalType, Optional, 1, HOISTABLE) + INST(EnumType, Enum, 1, PARENT) INST(DifferentialPairType, DiffPair, 1, HOISTABLE) INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE) @@ -1232,6 +1233,9 @@ INST(CastPtrToInt, CastPtrToInt, 1, 0) INST(CastIntToPtr, CastIntToPtr, 1, 0) INST(CastToVoid, castToVoid, 1, 0) INST(PtrCast, PtrCast, 1, 0) +INST(CastEnumToInt, CastEnumToInt, 1, 0) +INST(CastIntToEnum, CastIntToEnum, 1, 0) +INST(EnumCast, EnumCast, 1, 0) INST(CastUInt2ToDescriptorHandle, CastUInt2ToDescriptorHandle, 1, 0) INST(CastDescriptorHandleToUInt2, CastDescriptorHandleToUInt2, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 21d5e1c23..5b196cf24 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4299,6 +4299,9 @@ public: // Create an initially empty `class` type. IRClassType* createClassType(); + // Create an an `enum` type with the given tag type. + IREnumType* createEnumType(IRType* tagType); + // Create an initially empty `GLSLShaderStorageBufferType` type. IRGLSLShaderStorageBufferType* createGLSLShaderStorableBufferType(); IRGLSLShaderStorageBufferType* createGLSLShaderStorableBufferType( diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index df643a3c1..7b3aa4340 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -377,6 +377,13 @@ static Result _calcSizeAndAlignment( attributedType->getBaseType(), outSizeAndAlignment); } + case kIROp_EnumType: + { + auto enumType = cast<IREnumType>(type); + auto tagType = enumType->getTagType(); + return _calcSizeAndAlignment(optionSet, rules, tagType, outSizeAndAlignment); + } + break; default: break; } diff --git a/source/slang/slang-ir-lower-enum-type.cpp b/source/slang/slang-ir-lower-enum-type.cpp new file mode 100644 index 000000000..548c29a51 --- /dev/null +++ b/source/slang/slang-ir-lower-enum-type.cpp @@ -0,0 +1,149 @@ +// slang-ir-lower-enum-type.cpp + +#include "slang-ir-lower-enum-type.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ +struct EnumTypeLoweringContext +{ + IRModule* module; + DiagnosticSink* sink; + + InstWorkList workList; + InstHashSet workListSet; + + IRGeneric* genericOptionalStructType = nullptr; + IRStructKey* valueKey = nullptr; + IRStructKey* hasValueKey = nullptr; + + EnumTypeLoweringContext(IRModule* inModule) + : module(inModule), workList(inModule), workListSet(inModule) + { + } + + struct LoweredEnumTypeInfo : public RefObject + { + IRType* enumType = nullptr; + IRType* loweredType = nullptr; + }; + Dictionary<IRInst*, RefPtr<LoweredEnumTypeInfo>> loweredEnumTypes; + + void addToWorkList(IRInst* inst) + { + if (workListSet.contains(inst)) + return; + + workList.add(inst); + workListSet.add(inst); + } + + LoweredEnumTypeInfo* getLoweredEnumType(IRInst* type) + { + if (auto loweredInfo = loweredEnumTypes.tryGetValue(type)) + return loweredInfo->Ptr(); + + if (!type) + return nullptr; + + if (type->getOp() != kIROp_EnumType) + return nullptr; + + RefPtr<LoweredEnumTypeInfo> info = new LoweredEnumTypeInfo(); + auto enumType = cast<IREnumType>(type); + auto valueType = enumType->getTagType(); + info->enumType = (IRType*)type; + info->loweredType = valueType; + loweredEnumTypes[type] = info; + return info.Ptr(); + } + + void processEnumType(IREnumType* inst) + { + auto loweredEnumTypeInfo = getLoweredEnumType(inst); + SLANG_ASSERT(loweredEnumTypeInfo); + SLANG_UNUSED(loweredEnumTypeInfo); + } + + void processEnumCast(IRInst* inst) + { + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto value = inst->getOperand(0); + if (auto enumType = getLoweredEnumType(value->getDataType())) + { + auto rate = value->getRate(); + auto type = enumType->loweredType; + if (rate) + { + type = builder->getRateQualifiedType(rate, type); + } + + value->setFullType(type); + } + + auto type = inst->getDataType(); + if (auto enumType = getLoweredEnumType(type)) + { // Cast was into enum, so use tag type instead. + type = enumType->loweredType; + } + + auto cast = builder->emitCast(type, value); + + inst->replaceUsesWith(cast); + inst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_EnumType: + processEnumType((IREnumType*)inst); + break; + case kIROp_CastEnumToInt: + case kIROp_CastIntToEnum: + case kIROp_EnumCast: + processEnumCast(inst); + break; + default: + break; + } + } + + void processModule() + { + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + workList.removeLast(); + workListSet.remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + + // Replace all enum types with their lowered equivalent types. + for (const auto& [key, value] : loweredEnumTypes) + key->replaceUsesWith(value->loweredType); + } +}; + +void lowerEnumType(IRModule* module, DiagnosticSink* sink) +{ + EnumTypeLoweringContext context(module); + context.sink = sink; + context.processModule(); +} +} // namespace Slang diff --git a/source/slang/slang-ir-lower-enum-type.h b/source/slang/slang-ir-lower-enum-type.h new file mode 100644 index 000000000..37a465bcb --- /dev/null +++ b/source/slang/slang-ir-lower-enum-type.h @@ -0,0 +1,14 @@ +// slang-ir-lower-enum-type.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +/// Lower `IREnumType` to their underlying integer types. +void lowerEnumType(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 36aec22b5..5aae53747 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -106,6 +106,7 @@ IROp getTypeStyle(IROp op) { case kIROp_VoidType: case kIROp_BoolType: + case kIROp_EnumType: { return op; } @@ -220,6 +221,7 @@ bool isValueType(IRInst* dataType) case kIROp_FuncType: case kIROp_RaytracingAccelerationStructureType: case kIROp_GLSLAtomicUintType: + case kIROp_EnumType: return true; default: // Read-only resource handles are considered as Value type. @@ -271,6 +273,12 @@ bool isSimpleDataType(IRType* type) case kIROp_AnyValueType: case kIROp_PtrType: return true; + case kIROp_EnumType: + { + auto enumType = as<IREnumType>(type); + auto tagType = enumType->getTagType(); + return isSimpleDataType(tagType); + } case kIROp_ArrayType: case kIROp_UnsizedArrayType: return isSimpleDataType((IRType*)type->getOperand(0)); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index c105a698a..c1cec36fe 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4051,6 +4051,7 @@ enum class TypeCastStyle Float, Bool, Ptr, + Enum, Void }; static TypeCastStyle _getTypeStyleId(IRType* type) @@ -4083,6 +4084,8 @@ static TypeCastStyle _getTypeStyleId(IRType* type) case kIROp_RefType: case kIROp_ConstRefType: return TypeCastStyle::Ptr; + case kIROp_EnumType: + return TypeCastStyle::Enum; case kIROp_VoidType: return TypeCastStyle::Void; default: @@ -4131,27 +4134,42 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinC } }; - static const OpSeq opMap[4][5] = { - /* To: Int, Float, Bool, Ptr, Void*/ + static const OpSeq opMap[5][6] = { + /* To: Int, Float, Bool, Ptr, Enum, Void */ /* From Int */ { kIROp_IntCast, kIROp_CastIntToFloat, kIROp_IntCast, kIROp_CastIntToPtr, + kIROp_CastIntToEnum, kIROp_CastToVoid}, /* From Float */ {kIROp_CastFloatToInt, kIROp_FloatCast, {kIROp_Neq}, {kIROp_CastFloatToInt, kIROp_CastIntToPtr}, + {kIROp_CastFloatToInt, kIROp_CastIntToEnum}, kIROp_CastToVoid}, /* From Bool */ - {kIROp_IntCast, kIROp_CastIntToFloat, kIROp_Nop, kIROp_CastIntToPtr, kIROp_CastToVoid}, + {kIROp_IntCast, + kIROp_CastIntToFloat, + kIROp_Nop, + kIROp_CastIntToPtr, + kIROp_CastIntToEnum, + kIROp_CastToVoid}, /* From Ptr */ {kIROp_CastPtrToInt, {kIROp_CastPtrToInt, kIROp_CastIntToFloat}, kIROp_CastPtrToBool, kIROp_BitCast, + {kIROp_CastPtrToInt, kIROp_CastIntToEnum}, + kIROp_CastToVoid}, + /* From Enum */ + {kIROp_CastEnumToInt, + {kIROp_CastEnumToInt, kIROp_CastIntToFloat}, + {kIROp_CastEnumToInt, kIROp_IntCast}, + {kIROp_CastEnumToInt, kIROp_CastIntToPtr}, + kIROp_EnumCast, kIROp_CastToVoid}, }; @@ -4252,6 +4270,11 @@ IRInst* IRBuilder::emitVectorReshape(IRType* type, IRInst* value) } return emitMakeVector(targetVectorType, args); } + else + { + // Sizes match, no need to reshape. + return value; + } } auto reshape = emitIntrinsicInst( getVectorType(sourceVectorType->getElementType(), targetVectorType->getElementCount()), @@ -4807,6 +4830,13 @@ IRClassType* IRBuilder::createClassType() return classType; } +IREnumType* IRBuilder::createEnumType(IRType* tagType) +{ + IREnumType* enumType = createInst<IREnumType>(this, kIROp_EnumType, getTypeKind(), tagType); + addGlobalValue(this, enumType); + return enumType; +} + IRGLSLShaderStorageBufferType* IRBuilder::createGLSLShaderStorableBufferType() { IRGLSLShaderStorageBufferType* ssboType = createInst<IRGLSLShaderStorageBufferType>( @@ -8505,6 +8535,9 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_CastPtrToInt: case kIROp_CastIntToPtr: case kIROp_PtrCast: + case kIROp_CastEnumToInt: + case kIROp_CastIntToEnum: + case kIROp_EnumCast: case kIROp_CastUInt2ToDescriptorHandle: case kIROp_CastDescriptorHandleToUInt2: case kIROp_CastDescriptorHandleToResource: @@ -8956,6 +8989,9 @@ bool isMovableInst(IRInst* inst) case kIROp_CastPtrToBool: case kIROp_CastPtrToInt: case kIROp_PtrCast: + case kIROp_CastEnumToInt: + case kIROp_CastIntToEnum: + case kIROp_EnumCast: case kIROp_CastDynamicResource: case kIROp_BitAnd: case kIROp_BitNot: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index dbf2b91be..ea784c9a6 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2059,6 +2059,14 @@ struct IROptionalType : IRType IRType* getValueType() { return (IRType*)getOperand(0); } }; +/// Represents an enum type +struct IREnumType : IRType +{ + IR_LEAF_ISA(EnumType) + + IRType* getTagType() { return (IRType*)getOperand(0); } +}; + struct IRTypeType : IRType { IR_LEAF_ISA(TypeType); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 990090537..45efca2d9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9228,18 +9228,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto subContext = nestedContext.getContext(); auto outerGeneric = emitOuterGenerics(subContext, decl, decl); - // An `enum` declaration will currently lower directly to its "tag" - // type, so that any references to the `enum` become referenes to - // the tag type instead. - // // TODO: if we ever support `enum` types with payloads, we would // need to make the `enum` lower to some kind of custom "tagged union" // type. IRType* loweredTagType = lowerType(subContext, decl->tagType); + IRType* enumType = subBuilder->createEnumType(loweredTagType); + addLinkageDecoration(context, enumType, decl); - return LoweredValInfo::simple( - finishOuterGenerics(subBuilder, loweredTagType, outerGeneric)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, enumType, outerGeneric)); } LoweredValInfo visitThisTypeDecl(ThisTypeDecl* decl) |
