summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-21 14:00:48 -0700
committerGitHub <noreply@github.com>2023-09-21 14:00:48 -0700
commit5b2eb06816521cc0fcfe03258452560bd200002d (patch)
treedc06cc626ff0059dded3f4245f9309b3071ae94c /source
parentaf8ce68e9fd7b6255b6e4e9e9524a285497116dc (diff)
Various slangpy fixes. (#3227)
* Make dynamic cast transparent through `IRAttributedType`. * Add [CUDAXxx] variant of attributes. * Support marshaling of vector types. * Wrap cuda kernels in `extern "C"` block. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang9
-rw-r--r--source/slang/slang-emit-c-like.cpp27
-rw-r--r--source/slang/slang-ir-address-analysis.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp4
-rw-r--r--source/slang/slang-ir-autodiff.cpp2
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp2
-rw-r--r--source/slang/slang-ir-constexpr.cpp2
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-link.cpp4
-rw-r--r--source/slang/slang-ir-peephole.cpp2
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp51
-rw-r--r--source/slang/slang-ir-sccp.cpp2
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp2
-rw-r--r--source/slang/slang-ir-ssa-register-allocate.cpp2
-rw-r--r--source/slang/slang-ir-util.cpp6
-rw-r--r--source/slang/slang-ir-validate.cpp2
-rw-r--r--source/slang/slang-ir.cpp14
-rw-r--r--source/slang/slang-ir.h29
20 files changed, 142 insertions, 41 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 55cf5896f..43640eb41 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2290,6 +2290,15 @@ __attributeTarget(FuncDecl)
attribute_syntax [CudaKernel] : CudaKernelAttribute;
__attributeTarget(FuncDecl)
+attribute_syntax [CUDADeviceExport] : CudaDeviceExportAttribute;
+
+__attributeTarget(FuncDecl)
+attribute_syntax [CUDAHost] : CudaHostAttribute;
+
+__attributeTarget(FuncDecl)
+attribute_syntax [CUDAKernel] : CudaKernelAttribute;
+
+__attributeTarget(FuncDecl)
attribute_syntax[AutoPyBindCUDA] : AutoPyBindCudaAttribute;
__attributeTarget(AggTypeDecl)
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index a74459954..d26893987 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -3321,6 +3321,20 @@ bool CLikeSourceEmitter::isTargetIntrinsic(IRInst* inst)
return findTargetIntrinsicDefinition(inst, intrinsicDef);
}
+bool shouldWrappInExternCBlock(IRFunc* func)
+{
+ for (auto decor : func->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_ExternCDecoration:
+ case kIROp_CudaKernelDecoration:
+ return true;
+ }
+ }
+ return false;
+}
+
void CLikeSourceEmitter::emitFunc(IRFunc* func)
{
// Target-intrinsic functions should never be emitted
@@ -3329,6 +3343,15 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func)
if (isTargetIntrinsic(func))
return;
+ bool shouldCloseExternCBlock = shouldWrappInExternCBlock(func);
+ if (shouldCloseExternCBlock)
+ {
+ // If this is a C++ `extern "C"` function, then we need to emit
+ // it as a C function, since that is what the C++ compiler will
+ // expect.
+ //
+ m_writer->emit("extern \"C\" {\n");
+ }
if(!isDefinition(func))
{
@@ -3345,6 +3368,10 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func)
//
emitSimpleFunc(func);
}
+ if (shouldCloseExternCBlock)
+ {
+ m_writer->emit("}\n");
+ }
}
void CLikeSourceEmitter::emitFuncDecorationsImpl(IRFunc* func)
diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp
index b00b713b5..178e15292 100644
--- a/source/slang/slang-ir-address-analysis.cpp
+++ b/source/slang/slang-ir-address-analysis.cpp
@@ -64,7 +64,7 @@ namespace Slang
}
}
- if (!latestOperand || as<IRParam>(latestOperand))
+ if (!latestOperand || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(latestOperand))
inst->insertBefore(earliestBlock->getFirstOrdinaryInst());
else
inst->insertAfter(latestOperand);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index c906f93eb..10c8cdc51 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1647,7 +1647,7 @@ bool isLocalPointer(IRInst* ptrInst)
// referencing something outside the function scope.
//
auto addr = getRootAddr(ptrInst);
- return as<IRVar>(addr) || as<IRParam>(addr);
+ return as<IRVar>(addr) || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(addr);
}
void lowerSwizzledStores(IRModule* module, IRFunc* func)
@@ -2067,7 +2067,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
}
auto primalInst = cloneInst(&cloneEnv, builder, origParam);
- if (auto primalParam = as<IRParam>(primalInst))
+ if (auto primalParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalInst))
{
SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
primalParam->removeFromParent();
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index cb743c06a..59653c4ae 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -228,7 +228,7 @@ struct ExtractPrimalFuncContext
}
else
{
- if (as<IRParam>(inst))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
builder.setInsertBefore(block->getFirstOrdinaryInst());
else
builder.setInsertAfter(inst);
@@ -427,7 +427,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
for (auto inst : instsToRemove)
{
- if (as<IRParam>(inst))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
removePhiArgs(inst);
inst->removeAndDeallocate();
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 5b90e2711..a6437e3e4 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -324,7 +324,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
result = originalPairType;
return result;
}
- if (as<IRParam>(primalType))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType))
{
result = nullptr;
return result;
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 8124f9210..b937fe052 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -471,7 +471,7 @@ public:
auto block = as<IRBlock>(inst->getParent());
if (block != funcInst->getFirstBlock())
{
- auto paramIndex = getParamIndexInBlock(as<IRParam>(inst));
+ auto paramIndex = getParamIndexInBlock(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst));
if (paramIndex != -1)
{
for (auto p : block->getPredecessors())
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 6e56ebd96..34b56bfef 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -172,7 +172,7 @@ IRLoop* isLoopPhi(IRParam* param)
bool opCanBeConstExprByBackwardPass(IRInst* value)
{
if (value->getOp() == kIROp_Param)
- return isLoopPhi(as<IRParam>(value));
+ return isLoopPhi(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(value));
return opCanBeConstExpr(value->getOp());
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 6b5d8e59a..a087e59d7 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -766,6 +766,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// An extern_cpp decoration marks the inst to emit its name without mangling for C++ interop.
INST(ExternCppDecoration, externCpp, 1, 0)
+ // An externC decoration marks a function should be emitted inside an extern "C" block.
+ INST(ExternCDecoration, externC, 0, 0)
+
/// An dllImport decoration marks a function as imported from a DLL. Slang will generate dynamic function loading logic to use this function at runtime.
INST(DllImportDecoration, dllImport, 2, 0)
/// An dllExport decoration marks a function as an export symbol. Slang will generate a native wrapper function that is exported to DLL.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 0fd48f546..04f993838 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -543,6 +543,15 @@ struct IRExternCppDecoration : IRDecoration
UnownedStringSlice getName() { return getNameOperand()->getStringSlice(); }
};
+struct IRExternCDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_ExternCDecoration
+ };
+ IR_LEAF_ISA(ExternCDecoration)
+};
+
struct IRDllImportDecoration : IRDecoration
{
enum
@@ -4276,6 +4285,11 @@ public:
addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName));
}
+ void addExternCDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_ExternCDecoration);
+ }
+
void addForceInlineDecoration(IRInst* value)
{
addDecoration(value, kIROp_ForceInlineDecoration);
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index e37aa322e..b8f43c5f2 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -589,7 +589,9 @@ IRGeneric* cloneGenericImpl(
auto originalParam = originalGeneric->getFirstParam();
ShortList<KeyValuePair<IRInst*, IRInst*>> paramMapping;
- for (; clonedParam && originalParam; (clonedParam = as<IRParam>(clonedParam->next)), (originalParam = as<IRParam>(originalParam->next)))
+ for (; clonedParam && originalParam;
+ (clonedParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(clonedParam->next)),
+ (originalParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(originalParam->next)))
{
paramMapping.add(KeyValuePair<IRInst*, IRInst*>(clonedParam, originalParam));
}
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index f6e8a3458..d344590b1 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -776,7 +776,7 @@ struct PeepholeContext : InstPassBase
break;
UInt paramIndex = 0;
auto prevParam = inst->getPrevInst();
- while (as<IRParam>(prevParam))
+ while (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(prevParam))
{
prevParam = prevParam->getPrevInst();
paramIndex++;
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index c723902de..41665ddf7 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -193,7 +193,7 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
IRIntegerValue i = 0;
for (auto field : structType->getFields())
{
- auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i));
auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement);
if (!convertedElement)
return nullptr;
@@ -315,23 +315,38 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
inst->removeAndDeallocate();
}
-IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr)
+IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr)
{
- if (as<IRBasicType>(type))
+ if (as<IRBasicType>(type) || as<IRVectorType>(type))
return type;
switch (type->getOp())
{
case kIROp_TensorViewType:
return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
-
+#if 0
+ case kIROp_VectorType:
+ {
+ // Create a new struct type representing the vector.
+ auto hostStructType = builder->createStructType();
+ const char* names[4] = { "x", "y", "z", "w" };
+ for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++)
+ {
+ auto key = builder->createStructKey();
+ if (i < 4)
+ builder->addNameHintDecoration(key, UnownedStringSlice(names[i]));
+ builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType());
+ }
+ return hostStructType;
+ }
+#endif
case kIROp_StructType:
{
// Create a new struct type with translated fields.
List<IRType*> fieldTypes;
for (auto field : as<IRStructType>(type)->getFields())
{
- fieldTypes.add(translateToHostType(builder, field->getFieldType()));
+ fieldTypes.add(translateToHostType(builder, field->getFieldType(), func));
}
auto hostStructType = builder->createStructType();
@@ -348,12 +363,14 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* si
}
if (sink)
- sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type);
+ sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type, func);
return nullptr;
}
IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst)
{
+ if (hostType == cudaType)
+ return inst;
if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType))
return inst;
@@ -361,7 +378,18 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
{
case kIROp_TensorViewType:
return builder->emitMakeTensorView(cudaType, inst);
-
+#if 0
+ case kIROp_VectorType:
+ {
+ List<IRInst*> args;
+ auto hostStructType = cast<IRStructType>(hostType);
+ for (auto field : hostStructType->getFields())
+ {
+ args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey()));
+ }
+ return builder->emitMakeVector(cudaType, args);
+ }
+#endif
case kIROp_StructType:
{
auto cudaStructType = cast<IRStructType>(cudaType);
@@ -522,7 +550,7 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host
IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr)
{
- auto type = translateToHostType(builder, param->getDataType(), sink);
+ auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink);
if (outType)
*outType = type;
auto hostParam = builder->emitParam(type);
@@ -859,6 +887,7 @@ void handleAutoBindNames(IRModule* module)
nameBuilder << "__kernel__" << externCppHint->getName();
externCppHint->removeAndDeallocate();
builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice());
+ builder.addExternCDecoration(globalInst);
}
}
}
@@ -915,7 +944,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
// If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
if (func->findDecoration<IRCudaKernelDecoration>())
+ {
builder.addCudaKernelDecoration(wrapperFunc);
+ builder.addExternCDecoration(wrapperFunc);
+ }
// Add an auto-pybind-cuda decoration to the wrapper function to further generate the
// host-side binding for the derivative kernel.
@@ -971,7 +1003,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
// If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
if (func->findDecoration<IRCudaKernelDecoration>())
+ {
builder.addCudaKernelDecoration(wrapperFunc);
+ builder.addExternCDecoration(wrapperFunc);
+ }
// Add an auto-pybind-cuda decoration to the wrapper function to further generate the
// host-side binding for the derivative kernel.
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index 5ae858256..d874514fe 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -1022,7 +1022,7 @@ struct SCCPContext
// that provide arguments. We will see that logic shortly, when
// handling `IRUnconditionalBranch`.
//
- if(as<IRParam>(inst))
+ if(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
return;
// We want to special-case terminator instructions here,
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 44a8909e4..e00c24bdc 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -205,7 +205,7 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
auto rootAddr = getRootAddr(addr);
if (isGlobalOrUnknownMutableAddress(func, rootAddr))
return true;
- if (as<IRParam>(rootAddr))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(rootAddr))
return true;
// If we can't find the address from our map, we conservatively assume it is an unknown address.
diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp
index a93a3a8f4..ce502b454 100644
--- a/source/slang/slang-ir-ssa-register-allocate.cpp
+++ b/source/slang/slang-ir-ssa-register-allocate.cpp
@@ -103,7 +103,7 @@ struct RegisterAllocateContext
bool isUseOfParamAfterPhiAssignment(IRDominatorTree* dom, IRUse* useToTest, IRInst* phiParam, IRInst* phiArg)
{
- IRParam* param = as<IRParam>(phiParam);
+ IRParam* param = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(phiParam);
if (!param)
return false;
IRUse* branchUse = nullptr;
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 5ead1a1f4..4a92d263e 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -617,7 +617,7 @@ void removeLinkageDecorations(IRGlobalValueWithCode* func)
void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
{
- if (as<IRParam>(inst))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
{
SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
auto lastParam = as<IRBlock>(inst->getParent())->getLastParam();
@@ -631,7 +631,7 @@ void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst)
void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst)
{
- if (as<IRParam>(inst))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
{
SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent()));
auto lastParam = as<IRBlock>(inst->getParent())->getLastParam();
@@ -818,7 +818,7 @@ void moveParams(IRBlock* dest, IRBlock* src)
for (auto param = src->getFirstChild(); param;)
{
auto nextInst = param->getNextInst();
- if (as<IRDecoration>(param) || as<IRParam>(param))
+ if (as<IRDecoration>(param) || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(param))
{
param->insertAtEnd(dest);
}
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 2973e1ee5..af793bb54 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -80,7 +80,7 @@ namespace Slang
validate(context, state <= kState_AfterDecoration, child, "decorations must come before other child instructions");
state = kState_AfterDecoration;
}
- else if( as<IRParam>(child) )
+ else if( as<IRParam, IRDynamicCastBehavior::NoUnwrap>(child) )
{
validate(context, state <= kState_AfterParam, child, "parameters must come before ordinary instructions");
state = kState_AfterParam;
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 95d6c9e1f..ff6c8c39e 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -304,12 +304,12 @@ namespace Slang
IRParam* IRParam::getNextParam()
{
- return as<IRParam>(getNextInst());
+ return as<IRParam, IRDynamicCastBehavior::NoUnwrap>(getNextInst());
}
IRParam* IRParam::getPrevParam()
{
- return as<IRParam>(getPrevInst());
+ return as<IRParam, IRDynamicCastBehavior::NoUnwrap>(getPrevInst());
}
// IRArrayTypeBase
@@ -472,7 +472,7 @@ namespace Slang
// If the last instruction is a parameter, then
// there are no ordinary instructions, so the last
// one is a null pointer.
- if (as<IRParam>(inst))
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
return nullptr;
// Otherwise the last instruction is the last "ordinary"
@@ -1643,8 +1643,8 @@ namespace Slang
// instructions, so they need to come after
// any parameters of the parent.
//
- while(auto param = as<IRParam>(insertBeforeInst))
- insertBeforeInst = param->getNextInst();
+ while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param)
+ insertBeforeInst = insertBeforeInst->getNextInst();
// For instructions that will be placed at module scope,
// we don't care about relative ordering, but for everything
@@ -6425,14 +6425,14 @@ namespace Slang
// First walk through any `param` instructions,
// so that we can format them nicely
- if (auto firstParam = as<IRParam>(inst))
+ if (auto firstParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
{
dump(context, "(\n");
context->indent += 2;
for(;;)
{
- auto param = as<IRParam>(inst);
+ auto param = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst);
if (!param)
break;
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index c801f156e..ef803b642 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -828,31 +828,42 @@ struct IRInst
void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent);
};
-template<typename T>
+enum class IRDynamicCastBehavior
+{
+ Unwrap, NoUnwrap
+};
+
+template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap>
T* dynamicCast(IRInst* inst)
{
- if (inst && T::isaImpl(inst->getOp()))
+ if (!inst) return nullptr;
+ if (T::isaImpl(inst->getOp()))
return static_cast<T*>(inst);
+ if constexpr(behavior == IRDynamicCastBehavior::Unwrap)
+ {
+ if (inst->getOp() == kIROp_AttributedType)
+ return dynamicCast<T>(inst->getOperand(0));
+ }
return nullptr;
}
-template<typename T>
+template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap>
const T* dynamicCast(const IRInst* inst)
{
- return dynamicCast<T>(const_cast<IRInst*>(inst));
+ return dynamicCast<T, behavior>(const_cast<IRInst*>(inst));
}
// `dynamic_cast` equivalent (we just use dynamicCast)
-template<typename T>
+template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap>
T* as(IRInst* inst)
{
- return dynamicCast<T>(inst);
+ return dynamicCast<T, behavior>(inst);
}
-template<typename T>
+template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap>
const T* as(const IRInst* inst)
{
- return dynamicCast<T>(inst);
+ return dynamicCast<T, behavior>(inst);
}
// `static_cast` equivalent, with debug validation
@@ -1228,7 +1239,7 @@ struct IRBlock : IRInst
// instructions at the start of the block. These play
// the role of function parameters for the entry block
// of a function, and of phi nodes in other blocks.
- IRParam* getFirstParam() { return as<IRParam>(getFirstInst()); }
+ IRParam* getFirstParam() { return as<IRParam, IRDynamicCastBehavior::NoUnwrap>(getFirstInst()); }
IRParam* getLastParam();
IRInstList<IRParam> getParams()
{