summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSruthik P <spatibandlla@nvidia.com>2025-07-17 16:47:58 +0530
committerGitHub <noreply@github.com>2025-07-17 11:17:58 +0000
commit150ec59f9081d65f523e7fe8de7a0b75c402195d (patch)
treea9753d5effb6c618870d36c6cdfdd250450476cc /source
parentab5a815297e57f579b15023cd2ebe97db6bd33eb (diff)
slang: Add support for generating getters for IR struct defs. (#7725)
This change expands the IR struct definition generation logic in slang-ir.h.lua to code generate the getters for the operands of an IR. To facilitate the above, the schema for the IR definitions in slang-ir-insts.lua is updated to allow for explicit specification of the operands of an IR, with Fiddle code generating the getters for them. slang-ir.h is updated to remove the hardcoded getters of the IRs since they are now generated by Fiddle. Fixes part of #7185
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-insts.lua54
-rw-r--r--source/slang/slang-ir.h52
-rw-r--r--source/slang/slang-ir.h.lua97
3 files changed, 118 insertions, 85 deletions
diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua
index 6e256598f..0325fcbfd 100644
--- a/source/slang/slang-ir-insts.lua
+++ b/source/slang/slang-ir-insts.lua
@@ -10,7 +10,6 @@
--
-- For a detailed description of the schema, please see docs/design/ir-instruction-definition.md
--
-
local insts = {
{ nop = {} },
-- This opcode is used as a placeholder if we were ever to deserialize a
@@ -50,12 +49,12 @@ local insts = {
},
{ CapabilitySet = { struct_name = "CapabilitySetType", hoistable = true } },
{ DynamicType = { hoistable = true } },
- { AnyValueType = { min_operands = 1, hoistable = true } },
+ { AnyValueType = { operands = { { "size" } }, hoistable = true } },
{
RawPointerTypeBase = {
hoistable = true,
{ RawPointerType = {} },
- { RTTIPointerType = { min_operands = 1 } },
+ { RTTIPointerType = { operands = { { "rTTIOperand" } } } },
{ AfterRawPointerTypeBase = {} },
},
},
@@ -68,13 +67,13 @@ local insts = {
},
{ Func = { struct_name = "FuncType", hoistable = true } },
{ BasicBlock = { struct_name = "BasicBlockType", hoistable = true } },
- { Vec = { struct_name = "VectorType", min_operands = 2, hoistable = true } },
- { Mat = { struct_name = "MatrixType", min_operands = 4, hoistable = true } },
+ { Vec = { struct_name = "VectorType", operands = { { "elementType", "IRType" }, { "elementCount" } }, hoistable = true } },
+ { Mat = { struct_name = "MatrixType", operands = { { "elementType", "IRType" }, { "rowCount" }, { "columnCount" }, { "layout" } }, hoistable = true } },
{ Conjunction = { struct_name = "ConjunctionType", hoistable = true } },
- { Attributed = { struct_name = "AttributedType", hoistable = true } },
- { Result = { struct_name = "ResultType", min_operands = 2, hoistable = true } },
- { Optional = { struct_name = "OptionalType", min_operands = 1, hoistable = true } },
- { Enum = { struct_name = "EnumType", min_operands = 1, parent = true } },
+ { Attributed = { struct_name = "AttributedType", operands = { { "baseType", "IRType" }, { "attr" } }, hoistable = true } },
+ { Result = { struct_name = "ResultType", operands = { { "valueType", "IRType" }, { "errorType", "IRType" } }, hoistable = true } },
+ { Optional = { struct_name = "OptionalType", operands = { { "valueType", "IRType" } }, hoistable = true } },
+ { Enum = { struct_name = "EnumType", operands = { { "tagType", "IRType" } }, parent = true } },
{
DifferentialPairTypeBase = {
hoistable = true,
@@ -86,14 +85,14 @@ local insts = {
{
BwdDiffIntermediateCtxType = {
struct_name = "BackwardDiffIntermediateContextType",
- min_operands = 1,
+ operands = { { "func" } },
hoistable = true,
},
},
- { TensorView = { struct_name = "TensorViewType", min_operands = 1, hoistable = true } },
+ { TensorView = { struct_name = "TensorViewType", operands = { { "elementType", "IRType" } }, hoistable = true } },
{ TorchTensor = { struct_name = "TorchTensorType", hoistable = true } },
- { ArrayListVector = { struct_name = "ArrayListType", min_operands = 1, hoistable = true } },
- { Atomic = { struct_name = "AtomicType", min_operands = 1, hoistable = true } },
+ { ArrayListVector = { struct_name = "ArrayListType", operands = { { "elementType", "IRType" } }, hoistable = true } },
+ { Atomic = { struct_name = "AtomicType", operands = { { "elementType", "IRType" } }, hoistable = true } },
{
BindExistentialsTypeBase = {
hoistable = true,
@@ -128,7 +127,7 @@ local insts = {
{ ActualGlobalRate = {} },
},
},
- { RateQualified = { struct_name = "RateQualifiedType", min_operands = 2, hoistable = true } },
+ { RateQualified = { struct_name = "RateQualifiedType", operands = { { "rate", "IRRate" }, { "valueType", "IRType" } }, hoistable = true } },
{
Kind = {
-- Kinds represent the "types of types."
@@ -169,7 +168,7 @@ local insts = {
ComPtr = {
-- A ComPtr<T> type is treated as a opaque type that represents a reference-counted handle to a COM object.
struct_name = "ComPtrType",
- min_operands = 1,
+ operands = { { "valueType", "IRType" } },
hoistable = true,
},
},
@@ -177,7 +176,7 @@ local insts = {
NativePtr = {
-- A NativePtr<T> type represents a native pointer to a managed resource.
struct_name = "NativePtrType",
- min_operands = 1,
+ operands = { { "valueType", "IRType" } },
hoistable = true,
},
},
@@ -185,7 +184,7 @@ local insts = {
DescriptorHandle = {
-- A DescriptorHandle<T> type represents a bindless handle to an opaue resource type.
struct_name = "DescriptorHandleType",
- min_operands = 1,
+ operands = { { "resourceType", "IRType" } },
hoistable = true,
},
},
@@ -207,7 +206,7 @@ local insts = {
{ Std140Layout = { struct_name = "Std140BufferLayoutType", hoistable = true } },
{ Std430Layout = { struct_name = "Std430BufferLayoutType", hoistable = true } },
{ ScalarLayout = { struct_name = "ScalarBufferLayoutType", hoistable = true } },
- { SubpassInputType = { min_operands = 2, hoistable = true } },
+ { SubpassInputType = { operands = { { "elementType", "IRType" }, { "isMultisampleInst" } }, hoistable = true } },
{ TextureFootprintType = { min_operands = 1, hoistable = true } },
{ TextureShape1DType = { hoistable = true } },
{ TextureShape2DType = { struct_name = "TextureShape2DType", hoistable = true } },
@@ -274,7 +273,7 @@ local insts = {
{ Primitives = { struct_name = "PrimitivesType", min_operands = 2 } },
},
},
- { ["metal::mesh"] = { struct_name = "MetalMeshType", min_operands = 5 } },
+ { ["metal::mesh"] = { struct_name = "MetalMeshType", operands = { { "verticesType", "IRType" }, { "primitivesType", "IRType" }, { "numVertices" }, { "numPrimitives" }, { "topology", "IRIntLit" } } } },
{ mesh_grid_properties = { struct_name = "MetalMeshGridPropertiesType" } },
{
HLSLStructuredBufferTypeBase = {
@@ -350,10 +349,10 @@ local insts = {
hoistable = true,
},
},
- { CoopVectorType = { min_operands = 2, hoistable = true } },
- { CoopMatrixType = { min_operands = 5, hoistable = true } },
+ { CoopVectorType = { operands = { { "elementType", "IRType"}, { "elementCount" } }, hoistable = true } },
+ { CoopMatrixType = { operands = { { "elementType", "IRType"}, { "scope" }, { "rowCount" }, { "columnCount" }, { "matrixUse" } }, hoistable = true } },
{
- TensorAddressingTensorLayoutType = { min_operands = 2, hoistable = true },
+ TensorAddressingTensorLayoutType = { operands = { { "dimension"}, { "clampMode" } }, hoistable = true },
},
{
TensorAddressingTensorViewType = {
@@ -410,7 +409,7 @@ local insts = {
spirvLiteralType = {
-- A type that identifies it's contained type as being emittable as `spirv_literal.
struct_name = "SPIRVLiteralType",
- min_operands = 1,
+ operands = { { "valueType", "IRType" } },
hoistable = true,
},
},
@@ -618,7 +617,7 @@ local insts = {
{ packAnyValue = { min_operands = 1 } },
{ unpackAnyValue = { min_operands = 1 } },
{ witness_table_entry = { min_operands = 2 } },
- { interface_req_entry = { struct_name = "InterfaceRequirementEntry", min_operands = 2, global = true } },
+ { interface_req_entry = { struct_name = "InterfaceRequirementEntry", operands = { { "requirementKey" }, { "requirementVal" } }, global = true } },
-- An inst to represent the workgroup size of the calling entry point.
-- We will materialize this inst during `translateGlobalVaryingVar`.
{ GetWorkGroupSize = { hoistable = true } },
@@ -974,7 +973,7 @@ local insts = {
{ loopExitValue = { min_operands = 1 } },
{
getStringHash = {
- min_operands = 1,
+ operands = { { "stringLit", "IRStringLit" } },
},
},
{ waveGetActiveMask = {} },
@@ -2267,8 +2266,8 @@ local function process(insts)
end
end
- -- If it's a leaf and doesn't have min_operands, add it
- if is_leaf(value) and value.min_operands == nil then
+ -- If it's a leaf and doesn't have min_operands and operands, add min_operands = 0
+ if is_leaf(value) and value.min_operands == nil and value.operands == nil then
value.min_operands = 0
end
@@ -2320,7 +2319,6 @@ local function process(insts)
-- Start walking from the top-level insts
walk_insts(insts)
end
-
return {
insts = insts,
stable_name_to_inst = stable_name_to_inst,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 0281aa1c1..54fc3d8de 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1456,8 +1456,6 @@ FIDDLE()
struct IRSubpassInputType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
- IRInst* getIsMultisampleInst() { return getOperand(1); }
bool isMultisample() { return getIntVal(getIsMultisampleInst()) == 1; }
};
@@ -1509,12 +1507,6 @@ FIDDLE()
struct IRMetalMeshType : IRType
{
FIDDLE(leafInst())
-
- IRType* getVerticesType() { return (IRType*)getOperand(0); }
- IRType* getPrimitivesType() { return (IRType*)getOperand(1); }
- IRInst* getNumVertices() { return (IRInst*)getOperand(2); }
- IRInst* getNumPrimitives() { return (IRInst*)getOperand(3); }
- IRIntLit* getTopology() { return (IRIntLit*)getOperand(4); }
};
FIDDLE()
@@ -1584,8 +1576,6 @@ FIDDLE()
struct IRAtomicType : IRType
{
FIDDLE(leafInst())
-
- IRType* getElementType() { return (IRType*)getOperand(0); }
};
@@ -1593,15 +1583,12 @@ FIDDLE()
struct IRRateQualifiedType : IRType
{
FIDDLE(leafInst())
- IRRate* getRate() { return (IRRate*)getOperand(0); }
- IRType* getValueType() { return (IRType*)getOperand(1); }
};
FIDDLE()
struct IRDescriptorHandleType : IRType
{
FIDDLE(leafInst())
- IRType* getResourceType() { return (IRType*)getOperand(0); }
};
// Unlike the AST-level type system where `TypeType` tracks the
@@ -1649,39 +1636,30 @@ FIDDLE()
struct IRBackwardDiffIntermediateContextType : IRType
{
FIDDLE(leafInst())
- IRInst* getFunc() { return getOperand(0); }
};
FIDDLE()
struct IRVectorType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
- IRInst* getElementCount() { return getOperand(1); }
};
FIDDLE()
struct IRMatrixType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
- IRInst* getRowCount() { return getOperand(1); }
- IRInst* getColumnCount() { return getOperand(2); }
- IRInst* getLayout() { return getOperand(3); }
};
FIDDLE()
struct IRArrayListType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
};
FIDDLE()
struct IRTensorViewType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
};
FIDDLE()
@@ -1694,8 +1672,6 @@ FIDDLE()
struct IRSPIRVLiteralType : IRType
{
FIDDLE(leafInst())
-
- IRType* getValueType() { return static_cast<IRType*>(getOperand(0)); }
};
FIDDLE()
@@ -1721,14 +1697,12 @@ FIDDLE()
struct IRComPtrType : public IRType
{
FIDDLE(leafInst())
- IRType* getValueType() { return (IRType*)getOperand(0); }
};
FIDDLE()
struct IRNativePtrType : public IRType
{
FIDDLE(leafInst())
- IRType* getValueType() { return (IRType*)getOperand(0); }
};
FIDDLE()
@@ -1758,7 +1732,6 @@ FIDDLE()
struct IRRTTIPointerType : IRRawPointerTypeBase
{
FIDDLE(leafInst())
- IRInst* getRTTIOperand() { return getOperand(0); }
};
FIDDLE()
@@ -1771,8 +1744,6 @@ FIDDLE()
struct IRGetStringHash : IRInst
{
FIDDLE(leafInst())
-
- IRStringLit* getStringLit() { return as<IRStringLit>(getOperand(0)); }
};
/// Get the type pointed to be `ptrType`, or `nullptr` if it is not a pointer(-like) type.
@@ -1809,27 +1780,18 @@ FIDDLE()
struct IRCoopVectorType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
- IRInst* getElementCount() { return getOperand(1); }
};
FIDDLE()
struct IRCoopMatrixType : IRType
{
FIDDLE(leafInst())
- IRType* getElementType() { return (IRType*)getOperand(0); }
- IRInst* getScope() { return getOperand(1); }
- IRInst* getRowCount() { return getOperand(2); }
- IRInst* getColumnCount() { return getOperand(3); }
- IRInst* getMatrixUse() { return getOperand(4); }
};
FIDDLE()
struct IRTensorAddressingTensorLayoutType : IRType
{
FIDDLE(leafInst())
- IRInst* getDimension() { return getOperand(0); }
- IRInst* getClampMode() { return getOperand(1); }
};
FIDDLE()
@@ -1933,8 +1895,6 @@ FIDDLE()
struct IRInterfaceRequirementEntry : IRInst
{
FIDDLE(leafInst())
- IRInst* getRequirementKey() { return getOperand(0); }
- IRInst* getRequirementVal() { return getOperand(1); }
void setRequirementKey(IRInst* val) { setOperand(0, val); }
void setRequirementVal(IRInst* val) { setOperand(1, val); }
};
@@ -1960,9 +1920,6 @@ FIDDLE()
struct IRAttributedType : IRType
{
FIDDLE(leafInst())
-
- IRType* getBaseType() { return (IRType*)getOperand(0); }
- IRInst* getAttr() { return getOperand(1); }
};
FIDDLE()
@@ -2021,9 +1978,6 @@ FIDDLE()
struct IRResultType : IRType
{
FIDDLE(leafInst())
-
- IRType* getValueType() { return (IRType*)getOperand(0); }
- IRType* getErrorType() { return (IRType*)getOperand(1); }
};
/// Represents an `Optional<T>`.
@@ -2031,8 +1985,6 @@ FIDDLE()
struct IROptionalType : IRType
{
FIDDLE(leafInst())
-
- IRType* getValueType() { return (IRType*)getOperand(0); }
};
/// Represents an enum type
@@ -2040,8 +1992,6 @@ FIDDLE()
struct IREnumType : IRType
{
FIDDLE(leafInst())
-
- IRType* getTagType() { return (IRType*)getOperand(0); }
};
FIDDLE()
@@ -2069,7 +2019,6 @@ FIDDLE()
struct IRAnyValueType : IRType
{
FIDDLE(leafInst())
- IRInst* getSize() { return getOperand(0); }
};
FIDDLE()
@@ -2465,7 +2414,6 @@ private:
friend struct IRSerialReadContext;
friend struct IRSerialWriteContext;
friend struct Fossilized_IRModule;
-
IRModule() = delete;
/// Ctor
diff --git a/source/slang/slang-ir.h.lua b/source/slang/slang-ir.h.lua
index 4ce79455d..3f97b8f40 100644
--- a/source/slang/slang-ir.h.lua
+++ b/source/slang/slang-ir.h.lua
@@ -3,6 +3,27 @@
--
-- Helper function
+-- Find instruction data by struct name
+local function findInstData(insts, struct_name)
+ local function search(tbl)
+ for _, i in ipairs(tbl) do
+ local key, value = next(i)
+ local inst_struct_name = value.struct_name or key
+ if inst_struct_name == struct_name then
+ return value
+ end
+ -- Recursively search nested instructions
+ local result = search(value)
+ if result then
+ return result
+ end
+ end
+ return nil
+ end
+
+ return search(insts)
+end
+
-- Walk the instruction tree and call a callback for each instruction
local function walk_instructions(insts, callback, parent_struct)
local function walk_insts(tbl, parent)
@@ -28,7 +49,7 @@ end
-- The definitions for leaf instructions
local leafInst = function(name, args)
args = args or {}
- return args.noIsaImpl and ""
+ local result = args.noIsaImpl and ""
or [[static bool isaImpl(IROp op)
{
return (kIROpMask_OpMask & op) == kIROp_]]
@@ -38,12 +59,30 @@ local leafInst = function(name, args)
enum { kOp = kIROp_]]
.. name
.. [[ }; ]]
+
+ -- Add getter methods if operands are specified
+ if args.operands then
+ for i, operand in ipairs(args.operands) do
+ local operandName = operand[1]
+ local operandType = operand[2]
+ local getterName = "get" .. operandName:sub(1,1):upper() .. operandName:sub(2)
+ local returnType = "IRInst"
+ if operandType then
+ returnType = operandType
+ result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return (" .. returnType .. "*)getOperand(" .. (i-1) .. "); }"
+ else
+ result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return getOperand(" .. (i-1) .. "); }"
+ end
+ end
+ end
+
+ return result
end
-- The definitions for abstract instruction classes
local baseInst = function(name, args)
args = args or {}
- return args.noIsaImpl and ""
+ local result = args.noIsaImpl and ""
or [[static bool isaImpl(IROp opIn)
{
const int op = (kIROpMask_OpMask & opIn);
@@ -53,6 +92,24 @@ local baseInst = function(name, args)
.. name
.. [[;
}]]
+
+ -- Add getter methods if operands are specified
+ if args.operands then
+ for i, operand in ipairs(args.operands) do
+ local operandName = operand[1]
+ local operandType = operand[2]
+ local getterName = "get" .. operandName:sub(1,1):upper() .. operandName:sub(2)
+ local returnType = "IRInst"
+ if operandType then
+ returnType = operandType
+ result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return (" .. returnType .. "*)getOperand(" .. (i-1) .. "); }"
+ else
+ result = result .. "\n " .. returnType .. "* " .. getterName .. "() { return getOperand(" .. (i-1) .. "); }"
+ end
+ end
+ end
+
+ return result
end
-- Generate struct definitions for instructions not defined by the user
@@ -124,13 +181,21 @@ local function instInfoEntries()
walk_instructions(insts, function(key, value, struct_name, parent_struct)
if value.is_leaf then
+ -- Calculate operand count from operands array or use min_operands as fallback
+ local operand_count = 0
+ if value.operands then
+ operand_count = #value.operands
+ elseif value.min_operands then
+ operand_count = value.min_operands
+ end
+
RAW(
"{kIROp_"
.. struct_name
.. ', {"'
.. value.mnemonic
.. '", '
- .. tostring(value.min_operands)
+ .. tostring(operand_count)
.. ", "
.. constructFlags(value)
.. "}},"
@@ -191,10 +256,32 @@ end
return {
leafInst = function(args)
- return leafInst(tostring(fiddle.current_decl):gsub("^IR", ""), args)
+ -- Get the current instruction definition from the Lua data
+ local insts = require("source/slang/slang-ir-insts.lua").insts
+ local current_name = tostring(fiddle.current_decl):gsub("^IR", "")
+ local inst_data = findInstData(insts, current_name)
+
+ -- Merge the args with the instruction data
+ local merged_args = args or {}
+ if inst_data and inst_data.operands then
+ merged_args.operands = inst_data.operands
+ end
+
+ return leafInst(current_name, merged_args)
end,
baseInst = function(args)
- return baseInst(tostring(fiddle.current_decl):gsub("^IR", ""), args)
+ -- Get the current instruction definition from the Lua data
+ local insts = require("source/slang/slang-ir-insts.lua").insts
+ local current_name = tostring(fiddle.current_decl):gsub("^IR", "")
+ local inst_data = findInstData(insts, current_name)
+
+ -- Merge the args with the instruction data
+ local merged_args = args or {}
+ if inst_data and inst_data.operands then
+ merged_args.operands = inst_data.operands
+ end
+
+ return baseInst(current_name, merged_args)
end,
allOtherInstStructs = allOtherInstStructs,
instStructForwardDecls = instStructForwardDecls,