diff options
| -rw-r--r-- | source/slang/slang-ir-insts.lua | 54 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 52 | ||||
| -rw-r--r-- | source/slang/slang-ir.h.lua | 97 |
3 files changed, 118 insertions, 85 deletions
diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 6e256598f..0325fcbfd 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -10,7 +10,6 @@ -- -- For a detailed description of the schema, please see docs/design/ir-instruction-definition.md -- - local insts = { { nop = {} }, -- This opcode is used as a placeholder if we were ever to deserialize a @@ -50,12 +49,12 @@ local insts = { }, { CapabilitySet = { struct_name = "CapabilitySetType", hoistable = true } }, { DynamicType = { hoistable = true } }, - { AnyValueType = { min_operands = 1, hoistable = true } }, + { AnyValueType = { operands = { { "size" } }, hoistable = true } }, { RawPointerTypeBase = { hoistable = true, { RawPointerType = {} }, - { RTTIPointerType = { min_operands = 1 } }, + { RTTIPointerType = { operands = { { "rTTIOperand" } } } }, { AfterRawPointerTypeBase = {} }, }, }, @@ -68,13 +67,13 @@ local insts = { }, { Func = { struct_name = "FuncType", hoistable = true } }, { BasicBlock = { struct_name = "BasicBlockType", hoistable = true } }, - { Vec = { struct_name = "VectorType", min_operands = 2, hoistable = true } }, - { Mat = { struct_name = "MatrixType", min_operands = 4, hoistable = true } }, + { Vec = { struct_name = "VectorType", operands = { { "elementType", "IRType" }, { "elementCount" } }, hoistable = true } }, + { Mat = { struct_name = "MatrixType", operands = { { "elementType", "IRType" }, { "rowCount" }, { "columnCount" }, { "layout" } }, hoistable = true } }, { Conjunction = { struct_name = "ConjunctionType", hoistable = true } }, - { Attributed = { struct_name = "AttributedType", hoistable = true } }, - { Result = { struct_name = "ResultType", min_operands = 2, hoistable = true } }, - { Optional = { struct_name = "OptionalType", min_operands = 1, hoistable = true } }, - { Enum = { struct_name = "EnumType", min_operands = 1, parent = true } }, + { Attributed = { struct_name = "AttributedType", operands = { { "baseType", "IRType" }, { "attr" } }, hoistable = true } }, + { Result = { struct_name = "ResultType", operands = { { "valueType", "IRType" }, { "errorType", "IRType" } }, hoistable = true } }, + { Optional = { struct_name = "OptionalType", operands = { { "valueType", "IRType" } }, hoistable = true } }, + { Enum = { struct_name = "EnumType", operands = { { "tagType", "IRType" } }, parent = true } }, { DifferentialPairTypeBase = { hoistable = true, @@ -86,14 +85,14 @@ local insts = { { BwdDiffIntermediateCtxType = { struct_name = "BackwardDiffIntermediateContextType", - min_operands = 1, + operands = { { "func" } }, hoistable = true, }, }, - { TensorView = { struct_name = "TensorViewType", min_operands = 1, hoistable = true } }, + { TensorView = { struct_name = "TensorViewType", operands = { { "elementType", "IRType" } }, hoistable = true } }, { TorchTensor = { struct_name = "TorchTensorType", hoistable = true } }, - { ArrayListVector = { struct_name = "ArrayListType", min_operands = 1, hoistable = true } }, - { Atomic = { struct_name = "AtomicType", min_operands = 1, hoistable = true } }, + { ArrayListVector = { struct_name = "ArrayListType", operands = { { "elementType", "IRType" } }, hoistable = true } }, + { Atomic = { struct_name = "AtomicType", operands = { { "elementType", "IRType" } }, hoistable = true } }, { BindExistentialsTypeBase = { hoistable = true, @@ -128,7 +127,7 @@ local insts = { { ActualGlobalRate = {} }, }, }, - { RateQualified = { struct_name = "RateQualifiedType", min_operands = 2, hoistable = true } }, + { RateQualified = { struct_name = "RateQualifiedType", operands = { { "rate", "IRRate" }, { "valueType", "IRType" } }, hoistable = true } }, { Kind = { -- Kinds represent the "types of types." @@ -169,7 +168,7 @@ local insts = { ComPtr = { -- A ComPtr<T> type is treated as a opaque type that represents a reference-counted handle to a COM object. struct_name = "ComPtrType", - min_operands = 1, + operands = { { "valueType", "IRType" } }, hoistable = true, }, }, @@ -177,7 +176,7 @@ local insts = { NativePtr = { -- A NativePtr<T> type represents a native pointer to a managed resource. struct_name = "NativePtrType", - min_operands = 1, + operands = { { "valueType", "IRType" } }, hoistable = true, }, }, @@ -185,7 +184,7 @@ local insts = { DescriptorHandle = { -- A DescriptorHandle<T> type represents a bindless handle to an opaue resource type. struct_name = "DescriptorHandleType", - min_operands = 1, + operands = { { "resourceType", "IRType" } }, hoistable = true, }, }, @@ -207,7 +206,7 @@ local insts = { { Std140Layout = { struct_name = "Std140BufferLayoutType", hoistable = true } }, { Std430Layout = { struct_name = "Std430BufferLayoutType", hoistable = true } }, { ScalarLayout = { struct_name = "ScalarBufferLayoutType", hoistable = true } }, - { SubpassInputType = { min_operands = 2, hoistable = true } }, + { SubpassInputType = { operands = { { "elementType", "IRType" }, { "isMultisampleInst" } }, hoistable = true } }, { TextureFootprintType = { min_operands = 1, hoistable = true } }, { TextureShape1DType = { hoistable = true } }, { TextureShape2DType = { struct_name = "TextureShape2DType", hoistable = true } }, @@ -274,7 +273,7 @@ local insts = { { Primitives = { struct_name = "PrimitivesType", min_operands = 2 } }, }, }, - { ["metal::mesh"] = { struct_name = "MetalMeshType", min_operands = 5 } }, + { ["metal::mesh"] = { struct_name = "MetalMeshType", operands = { { "verticesType", "IRType" }, { "primitivesType", "IRType" }, { "numVertices" }, { "numPrimitives" }, { "topology", "IRIntLit" } } } }, { mesh_grid_properties = { struct_name = "MetalMeshGridPropertiesType" } }, { HLSLStructuredBufferTypeBase = { @@ -350,10 +349,10 @@ local insts = { hoistable = true, }, }, - { CoopVectorType = { min_operands = 2, hoistable = true } }, - { CoopMatrixType = { min_operands = 5, hoistable = true } }, + { CoopVectorType = { operands = { { "elementType", "IRType"}, { "elementCount" } }, hoistable = true } }, + { CoopMatrixType = { operands = { { "elementType", "IRType"}, { "scope" }, { "rowCount" }, { "columnCount" }, { "matrixUse" } }, hoistable = true } }, { - TensorAddressingTensorLayoutType = { min_operands = 2, hoistable = true }, + TensorAddressingTensorLayoutType = { operands = { { "dimension"}, { "clampMode" } }, hoistable = true }, }, { TensorAddressingTensorViewType = { @@ -410,7 +409,7 @@ local insts = { spirvLiteralType = { -- A type that identifies it's contained type as being emittable as `spirv_literal. struct_name = "SPIRVLiteralType", - min_operands = 1, + operands = { { "valueType", "IRType" } }, hoistable = true, }, }, @@ -618,7 +617,7 @@ local insts = { { packAnyValue = { min_operands = 1 } }, { unpackAnyValue = { min_operands = 1 } }, { witness_table_entry = { min_operands = 2 } }, - { interface_req_entry = { struct_name = "InterfaceRequirementEntry", min_operands = 2, global = true } }, + { interface_req_entry = { struct_name = "InterfaceRequirementEntry", operands = { { "requirementKey" }, { "requirementVal" } }, global = true } }, -- An inst to represent the workgroup size of the calling entry point. -- We will materialize this inst during `translateGlobalVaryingVar`. { GetWorkGroupSize = { hoistable = true } }, @@ -974,7 +973,7 @@ local insts = { { loopExitValue = { min_operands = 1 } }, { getStringHash = { - min_operands = 1, + operands = { { "stringLit", "IRStringLit" } }, }, }, { waveGetActiveMask = {} }, @@ -2267,8 +2266,8 @@ local function process(insts) end end - -- If it's a leaf and doesn't have min_operands, add it - if is_leaf(value) and value.min_operands == nil then + -- If it's a leaf and doesn't have min_operands and operands, add min_operands = 0 + if is_leaf(value) and value.min_operands == nil and value.operands == nil then value.min_operands = 0 end @@ -2320,7 +2319,6 @@ local function process(insts) -- Start walking from the top-level insts walk_insts(insts) end - return { insts = insts, stable_name_to_inst = stable_name_to_inst, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 0281aa1c1..54fc3d8de 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1456,8 +1456,6 @@ FIDDLE() struct IRSubpassInputType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getIsMultisampleInst() { return getOperand(1); } bool isMultisample() { return getIntVal(getIsMultisampleInst()) == 1; } }; @@ -1509,12 +1507,6 @@ FIDDLE() struct IRMetalMeshType : IRType { FIDDLE(leafInst()) - - IRType* getVerticesType() { return (IRType*)getOperand(0); } - IRType* getPrimitivesType() { return (IRType*)getOperand(1); } - IRInst* getNumVertices() { return (IRInst*)getOperand(2); } - IRInst* getNumPrimitives() { return (IRInst*)getOperand(3); } - IRIntLit* getTopology() { return (IRIntLit*)getOperand(4); } }; FIDDLE() @@ -1584,8 +1576,6 @@ FIDDLE() struct IRAtomicType : IRType { FIDDLE(leafInst()) - - IRType* getElementType() { return (IRType*)getOperand(0); } }; @@ -1593,15 +1583,12 @@ FIDDLE() struct IRRateQualifiedType : IRType { FIDDLE(leafInst()) - IRRate* getRate() { return (IRRate*)getOperand(0); } - IRType* getValueType() { return (IRType*)getOperand(1); } }; FIDDLE() struct IRDescriptorHandleType : IRType { FIDDLE(leafInst()) - IRType* getResourceType() { return (IRType*)getOperand(0); } }; // Unlike the AST-level type system where `TypeType` tracks the @@ -1649,39 +1636,30 @@ FIDDLE() struct IRBackwardDiffIntermediateContextType : IRType { FIDDLE(leafInst()) - IRInst* getFunc() { return getOperand(0); } }; FIDDLE() struct IRVectorType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getElementCount() { return getOperand(1); } }; FIDDLE() struct IRMatrixType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getRowCount() { return getOperand(1); } - IRInst* getColumnCount() { return getOperand(2); } - IRInst* getLayout() { return getOperand(3); } }; FIDDLE() struct IRArrayListType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } }; FIDDLE() struct IRTensorViewType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } }; FIDDLE() @@ -1694,8 +1672,6 @@ FIDDLE() struct IRSPIRVLiteralType : IRType { FIDDLE(leafInst()) - - IRType* getValueType() { return static_cast<IRType*>(getOperand(0)); } }; FIDDLE() @@ -1721,14 +1697,12 @@ FIDDLE() struct IRComPtrType : public IRType { FIDDLE(leafInst()) - IRType* getValueType() { return (IRType*)getOperand(0); } }; FIDDLE() struct IRNativePtrType : public IRType { FIDDLE(leafInst()) - IRType* getValueType() { return (IRType*)getOperand(0); } }; FIDDLE() @@ -1758,7 +1732,6 @@ FIDDLE() struct IRRTTIPointerType : IRRawPointerTypeBase { FIDDLE(leafInst()) - IRInst* getRTTIOperand() { return getOperand(0); } }; FIDDLE() @@ -1771,8 +1744,6 @@ FIDDLE() struct IRGetStringHash : IRInst { FIDDLE(leafInst()) - - IRStringLit* getStringLit() { return as<IRStringLit>(getOperand(0)); } }; /// Get the type pointed to be `ptrType`, or `nullptr` if it is not a pointer(-like) type. @@ -1809,27 +1780,18 @@ FIDDLE() struct IRCoopVectorType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getElementCount() { return getOperand(1); } }; FIDDLE() struct IRCoopMatrixType : IRType { FIDDLE(leafInst()) - IRType* getElementType() { return (IRType*)getOperand(0); } - IRInst* getScope() { return getOperand(1); } - IRInst* getRowCount() { return getOperand(2); } - IRInst* getColumnCount() { return getOperand(3); } - IRInst* getMatrixUse() { return getOperand(4); } }; FIDDLE() struct IRTensorAddressingTensorLayoutType : IRType { FIDDLE(leafInst()) - IRInst* getDimension() { return getOperand(0); } - IRInst* getClampMode() { return getOperand(1); } }; FIDDLE() @@ -1933,8 +1895,6 @@ FIDDLE() struct IRInterfaceRequirementEntry : IRInst { FIDDLE(leafInst()) - IRInst* getRequirementKey() { return getOperand(0); } - IRInst* getRequirementVal() { return getOperand(1); } void setRequirementKey(IRInst* val) { setOperand(0, val); } void setRequirementVal(IRInst* val) { setOperand(1, val); } }; @@ -1960,9 +1920,6 @@ FIDDLE() struct IRAttributedType : IRType { FIDDLE(leafInst()) - - IRType* getBaseType() { return (IRType*)getOperand(0); } - IRInst* getAttr() { return getOperand(1); } }; FIDDLE() @@ -2021,9 +1978,6 @@ FIDDLE() struct IRResultType : IRType { FIDDLE(leafInst()) - - IRType* getValueType() { return (IRType*)getOperand(0); } - IRType* getErrorType() { return (IRType*)getOperand(1); } }; /// Represents an `Optional<T>`. @@ -2031,8 +1985,6 @@ FIDDLE() struct IROptionalType : IRType { FIDDLE(leafInst()) - - IRType* getValueType() { return (IRType*)getOperand(0); } }; /// Represents an enum type @@ -2040,8 +1992,6 @@ FIDDLE() struct IREnumType : IRType { FIDDLE(leafInst()) - - IRType* getTagType() { return (IRType*)getOperand(0); } }; FIDDLE() @@ -2069,7 +2019,6 @@ FIDDLE() struct IRAnyValueType : IRType { FIDDLE(leafInst()) - IRInst* getSize() { return getOperand(0); } }; FIDDLE() @@ -2465,7 +2414,6 @@ private: friend struct IRSerialReadContext; friend struct IRSerialWriteContext; friend struct Fossilized_IRModule; - IRModule() = delete; /// Ctor diff --git a/source/slang/slang-ir.h.lua b/source/slang/slang-ir.h.lua index 4ce79455d..3f97b8f40 100644 --- a/source/slang/slang-ir.h.lua +++ b/source/slang/slang-ir.h.lua @@ -3,6 +3,27 @@ -- -- Helper function +-- Find instruction data by struct name +local function findInstData(insts, struct_name) + local function search(tbl) + for _, i in ipairs(tbl) do + local key, value = next(i) + local inst_struct_name = value.struct_name or key + if inst_struct_name == struct_name then + return value + end + -- Recursively search nested instructions + local result = search(value) + if result then + return result + end + end + return nil + end + + return search(insts) +end + -- Walk the instruction tree and call a callback for each instruction local function walk_instructions(insts, callback, parent_struct) local function walk_insts(tbl, parent) @@ -28,7 +49,7 @@ end -- The definitions for leaf instructions local leafInst = function(name, args) args = args or {} - return args.noIsaImpl and "" + local result = args.noIsaImpl and "" or [[static bool isaImpl(IROp op) { return (kIROpMask_OpMask & op) == kIROp_]] @@ -38,12 +59,30 @@ local leafInst = function(name, args) enum { kOp = kIROp_]] .. name .. [[ }; ]] + + -- Add getter methods if operands are specified + if args.operands then + for i, operand in ipairs(args.operands) do + local operandName = operand[1] + local operandType = operand[2] + local getterName = "get" .. operandName:sub(1,1):upper() .. operandName:sub(2) + local returnType = "IRInst" + if operandType then + returnType = operandType + result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return (" .. returnType .. "*)getOperand(" .. (i-1) .. "); }" + else + result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return getOperand(" .. (i-1) .. "); }" + end + end + end + + return result end -- The definitions for abstract instruction classes local baseInst = function(name, args) args = args or {} - return args.noIsaImpl and "" + local result = args.noIsaImpl and "" or [[static bool isaImpl(IROp opIn) { const int op = (kIROpMask_OpMask & opIn); @@ -53,6 +92,24 @@ local baseInst = function(name, args) .. name .. [[; }]] + + -- Add getter methods if operands are specified + if args.operands then + for i, operand in ipairs(args.operands) do + local operandName = operand[1] + local operandType = operand[2] + local getterName = "get" .. operandName:sub(1,1):upper() .. operandName:sub(2) + local returnType = "IRInst" + if operandType then + returnType = operandType + result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return (" .. returnType .. "*)getOperand(" .. (i-1) .. "); }" + else + result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return getOperand(" .. (i-1) .. "); }" + end + end + end + + return result end -- Generate struct definitions for instructions not defined by the user @@ -124,13 +181,21 @@ local function instInfoEntries() walk_instructions(insts, function(key, value, struct_name, parent_struct) if value.is_leaf then + -- Calculate operand count from operands array or use min_operands as fallback + local operand_count = 0 + if value.operands then + operand_count = #value.operands + elseif value.min_operands then + operand_count = value.min_operands + end + RAW( "{kIROp_" .. struct_name .. ', {"' .. value.mnemonic .. '", ' - .. tostring(value.min_operands) + .. tostring(operand_count) .. ", " .. constructFlags(value) .. "}}," @@ -191,10 +256,32 @@ end return { leafInst = function(args) - return leafInst(tostring(fiddle.current_decl):gsub("^IR", ""), args) + -- Get the current instruction definition from the Lua data + local insts = require("source/slang/slang-ir-insts.lua").insts + local current_name = tostring(fiddle.current_decl):gsub("^IR", "") + local inst_data = findInstData(insts, current_name) + + -- Merge the args with the instruction data + local merged_args = args or {} + if inst_data and inst_data.operands then + merged_args.operands = inst_data.operands + end + + return leafInst(current_name, merged_args) end, baseInst = function(args) - return baseInst(tostring(fiddle.current_decl):gsub("^IR", ""), args) + -- Get the current instruction definition from the Lua data + local insts = require("source/slang/slang-ir-insts.lua").insts + local current_name = tostring(fiddle.current_decl):gsub("^IR", "") + local inst_data = findInstData(insts, current_name) + + -- Merge the args with the instruction data + local merged_args = args or {} + if inst_data and inst_data.operands then + merged_args.operands = inst_data.operands + end + + return baseInst(current_name, merged_args) end, allOtherInstStructs = allOtherInstStructs, instStructForwardDecls = instStructForwardDecls, |
