diff options
| -rw-r--r-- | source/slang/core.meta.slang | 9 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-address-analysis.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-constexpr.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 51 | ||||
| -rw-r--r-- | source/slang/slang-ir-sccp.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-register-allocate.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 29 | ||||
| -rw-r--r-- | tests/autodiff/modify-vector-param.slang | 35 |
21 files changed, 177 insertions, 41 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 55cf5896f..43640eb41 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2290,6 +2290,15 @@ __attributeTarget(FuncDecl) attribute_syntax [CudaKernel] : CudaKernelAttribute; __attributeTarget(FuncDecl) +attribute_syntax [CUDADeviceExport] : CudaDeviceExportAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [CUDAHost] : CudaHostAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [CUDAKernel] : CudaKernelAttribute; + +__attributeTarget(FuncDecl) attribute_syntax[AutoPyBindCUDA] : AutoPyBindCudaAttribute; __attributeTarget(AggTypeDecl) diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index a74459954..d26893987 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3321,6 +3321,20 @@ bool CLikeSourceEmitter::isTargetIntrinsic(IRInst* inst) return findTargetIntrinsicDefinition(inst, intrinsicDef); } +bool shouldWrappInExternCBlock(IRFunc* func) +{ + for (auto decor : func->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ExternCDecoration: + case kIROp_CudaKernelDecoration: + return true; + } + } + return false; +} + void CLikeSourceEmitter::emitFunc(IRFunc* func) { // Target-intrinsic functions should never be emitted @@ -3329,6 +3343,15 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func) if (isTargetIntrinsic(func)) return; + bool shouldCloseExternCBlock = shouldWrappInExternCBlock(func); + if (shouldCloseExternCBlock) + { + // If this is a C++ `extern "C"` function, then we need to emit + // it as a C function, since that is what the C++ compiler will + // expect. + // + m_writer->emit("extern \"C\" {\n"); + } if(!isDefinition(func)) { @@ -3345,6 +3368,10 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func) // emitSimpleFunc(func); } + if (shouldCloseExternCBlock) + { + m_writer->emit("}\n"); + } } void CLikeSourceEmitter::emitFuncDecorationsImpl(IRFunc* func) diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp index b00b713b5..178e15292 100644 --- a/source/slang/slang-ir-address-analysis.cpp +++ b/source/slang/slang-ir-address-analysis.cpp @@ -64,7 +64,7 @@ namespace Slang } } - if (!latestOperand || as<IRParam>(latestOperand)) + if (!latestOperand || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(latestOperand)) inst->insertBefore(earliestBlock->getFirstOrdinaryInst()); else inst->insertAfter(latestOperand); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c906f93eb..10c8cdc51 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1647,7 +1647,7 @@ bool isLocalPointer(IRInst* ptrInst) // referencing something outside the function scope. // auto addr = getRootAddr(ptrInst); - return as<IRVar>(addr) || as<IRParam>(addr); + return as<IRVar>(addr) || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(addr); } void lowerSwizzledStores(IRModule* module, IRFunc* func) @@ -2067,7 +2067,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam } auto primalInst = cloneInst(&cloneEnv, builder, origParam); - if (auto primalParam = as<IRParam>(primalInst)) + if (auto primalParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalInst)) { SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); primalParam->removeFromParent(); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index cb743c06a..59653c4ae 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -228,7 +228,7 @@ struct ExtractPrimalFuncContext } else { - if (as<IRParam>(inst)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) builder.setInsertBefore(block->getFirstOrdinaryInst()); else builder.setInsertAfter(inst); @@ -427,7 +427,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( for (auto inst : instsToRemove) { - if (as<IRParam>(inst)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) removePhiArgs(inst); inst->removeAndDeallocate(); } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 5b90e2711..a6437e3e4 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -324,7 +324,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( result = originalPairType; return result; } - if (as<IRParam>(primalType)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType)) { result = nullptr; return result; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8124f9210..b937fe052 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -471,7 +471,7 @@ public: auto block = as<IRBlock>(inst->getParent()); if (block != funcInst->getFirstBlock()) { - auto paramIndex = getParamIndexInBlock(as<IRParam>(inst)); + auto paramIndex = getParamIndexInBlock(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)); if (paramIndex != -1) { for (auto p : block->getPredecessors()) diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 6e56ebd96..34b56bfef 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -172,7 +172,7 @@ IRLoop* isLoopPhi(IRParam* param) bool opCanBeConstExprByBackwardPass(IRInst* value) { if (value->getOp() == kIROp_Param) - return isLoopPhi(as<IRParam>(value)); + return isLoopPhi(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(value)); return opCanBeConstExpr(value->getOp()); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 6b5d8e59a..a087e59d7 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -766,6 +766,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// An extern_cpp decoration marks the inst to emit its name without mangling for C++ interop. INST(ExternCppDecoration, externCpp, 1, 0) + // An externC decoration marks a function should be emitted inside an extern "C" block. + INST(ExternCDecoration, externC, 0, 0) + /// An dllImport decoration marks a function as imported from a DLL. Slang will generate dynamic function loading logic to use this function at runtime. INST(DllImportDecoration, dllImport, 2, 0) /// An dllExport decoration marks a function as an export symbol. Slang will generate a native wrapper function that is exported to DLL. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 0fd48f546..04f993838 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -543,6 +543,15 @@ struct IRExternCppDecoration : IRDecoration UnownedStringSlice getName() { return getNameOperand()->getStringSlice(); } }; +struct IRExternCDecoration : IRDecoration +{ + enum + { + kOp = kIROp_ExternCDecoration + }; + IR_LEAF_ISA(ExternCDecoration) +}; + struct IRDllImportDecoration : IRDecoration { enum @@ -4276,6 +4285,11 @@ public: addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName)); } + void addExternCDecoration(IRInst* value) + { + addDecoration(value, kIROp_ExternCDecoration); + } + void addForceInlineDecoration(IRInst* value) { addDecoration(value, kIROp_ForceInlineDecoration); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index e37aa322e..b8f43c5f2 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -589,7 +589,9 @@ IRGeneric* cloneGenericImpl( auto originalParam = originalGeneric->getFirstParam(); ShortList<KeyValuePair<IRInst*, IRInst*>> paramMapping; - for (; clonedParam && originalParam; (clonedParam = as<IRParam>(clonedParam->next)), (originalParam = as<IRParam>(originalParam->next))) + for (; clonedParam && originalParam; + (clonedParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(clonedParam->next)), + (originalParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(originalParam->next))) { paramMapping.add(KeyValuePair<IRInst*, IRInst*>(clonedParam, originalParam)); } diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index f6e8a3458..d344590b1 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -776,7 +776,7 @@ struct PeepholeContext : InstPassBase break; UInt paramIndex = 0; auto prevParam = inst->getPrevInst(); - while (as<IRParam>(prevParam)) + while (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(prevParam)) { prevParam = prevParam->getPrevInst(); paramIndex++; diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index c723902de..41665ddf7 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -193,7 +193,7 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst IRIntegerValue i = 0; for (auto field : structType->getFields()) { - auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); if (!convertedElement) return nullptr; @@ -315,23 +315,38 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) inst->removeAndDeallocate(); } -IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr) +IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) { - if (as<IRBasicType>(type)) + if (as<IRBasicType>(type) || as<IRVectorType>(type)) return type; switch (type->getOp()) { case kIROp_TensorViewType: return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); - +#if 0 + case kIROp_VectorType: + { + // Create a new struct type representing the vector. + auto hostStructType = builder->createStructType(); + const char* names[4] = { "x", "y", "z", "w" }; + for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++) + { + auto key = builder->createStructKey(); + if (i < 4) + builder->addNameHintDecoration(key, UnownedStringSlice(names[i])); + builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType()); + } + return hostStructType; + } +#endif case kIROp_StructType: { // Create a new struct type with translated fields. List<IRType*> fieldTypes; for (auto field : as<IRStructType>(type)->getFields()) { - fieldTypes.add(translateToHostType(builder, field->getFieldType())); + fieldTypes.add(translateToHostType(builder, field->getFieldType(), func)); } auto hostStructType = builder->createStructType(); @@ -348,12 +363,14 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* si } if (sink) - sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type); + sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type, func); return nullptr; } IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst) { + if (hostType == cudaType) + return inst; if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType)) return inst; @@ -361,7 +378,18 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp { case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); - +#if 0 + case kIROp_VectorType: + { + List<IRInst*> args; + auto hostStructType = cast<IRStructType>(hostType); + for (auto field : hostStructType->getFields()) + { + args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey())); + } + return builder->emitMakeVector(cudaType, args); + } +#endif case kIROp_StructType: { auto cudaStructType = cast<IRStructType>(cudaType); @@ -522,7 +550,7 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr) { - auto type = translateToHostType(builder, param->getDataType(), sink); + auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink); if (outType) *outType = type; auto hostParam = builder->emitParam(type); @@ -859,6 +887,7 @@ void handleAutoBindNames(IRModule* module) nameBuilder << "__kernel__" << externCppHint->getName(); externCppHint->removeAndDeallocate(); builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); + builder.addExternCDecoration(globalInst); } } } @@ -915,7 +944,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. if (func->findDecoration<IRCudaKernelDecoration>()) + { builder.addCudaKernelDecoration(wrapperFunc); + builder.addExternCDecoration(wrapperFunc); + } // Add an auto-pybind-cuda decoration to the wrapper function to further generate the // host-side binding for the derivative kernel. @@ -971,7 +1003,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. if (func->findDecoration<IRCudaKernelDecoration>()) + { builder.addCudaKernelDecoration(wrapperFunc); + builder.addExternCDecoration(wrapperFunc); + } // Add an auto-pybind-cuda decoration to the wrapper function to further generate the // host-side binding for the derivative kernel. diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index 5ae858256..d874514fe 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -1022,7 +1022,7 @@ struct SCCPContext // that provide arguments. We will see that logic shortly, when // handling `IRUnconditionalBranch`. // - if(as<IRParam>(inst)) + if(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) return; // We want to special-case terminator instructions here, diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 44a8909e4..e00c24bdc 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -205,7 +205,7 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) auto rootAddr = getRootAddr(addr); if (isGlobalOrUnknownMutableAddress(func, rootAddr)) return true; - if (as<IRParam>(rootAddr)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(rootAddr)) return true; // If we can't find the address from our map, we conservatively assume it is an unknown address. diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp index a93a3a8f4..ce502b454 100644 --- a/source/slang/slang-ir-ssa-register-allocate.cpp +++ b/source/slang/slang-ir-ssa-register-allocate.cpp @@ -103,7 +103,7 @@ struct RegisterAllocateContext bool isUseOfParamAfterPhiAssignment(IRDominatorTree* dom, IRUse* useToTest, IRInst* phiParam, IRInst* phiArg) { - IRParam* param = as<IRParam>(phiParam); + IRParam* param = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(phiParam); if (!param) return false; IRUse* branchUse = nullptr; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 5ead1a1f4..4a92d263e 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -617,7 +617,7 @@ void removeLinkageDecorations(IRGlobalValueWithCode* func) void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) { - if (as<IRParam>(inst)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) { SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); @@ -631,7 +631,7 @@ void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) { - if (as<IRParam>(inst)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) { SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); auto lastParam = as<IRBlock>(inst->getParent())->getLastParam(); @@ -818,7 +818,7 @@ void moveParams(IRBlock* dest, IRBlock* src) for (auto param = src->getFirstChild(); param;) { auto nextInst = param->getNextInst(); - if (as<IRDecoration>(param) || as<IRParam>(param)) + if (as<IRDecoration>(param) || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(param)) { param->insertAtEnd(dest); } diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 2973e1ee5..af793bb54 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -80,7 +80,7 @@ namespace Slang validate(context, state <= kState_AfterDecoration, child, "decorations must come before other child instructions"); state = kState_AfterDecoration; } - else if( as<IRParam>(child) ) + else if( as<IRParam, IRDynamicCastBehavior::NoUnwrap>(child) ) { validate(context, state <= kState_AfterParam, child, "parameters must come before ordinary instructions"); state = kState_AfterParam; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 95d6c9e1f..ff6c8c39e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -304,12 +304,12 @@ namespace Slang IRParam* IRParam::getNextParam() { - return as<IRParam>(getNextInst()); + return as<IRParam, IRDynamicCastBehavior::NoUnwrap>(getNextInst()); } IRParam* IRParam::getPrevParam() { - return as<IRParam>(getPrevInst()); + return as<IRParam, IRDynamicCastBehavior::NoUnwrap>(getPrevInst()); } // IRArrayTypeBase @@ -472,7 +472,7 @@ namespace Slang // If the last instruction is a parameter, then // there are no ordinary instructions, so the last // one is a null pointer. - if (as<IRParam>(inst)) + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) return nullptr; // Otherwise the last instruction is the last "ordinary" @@ -1643,8 +1643,8 @@ namespace Slang // instructions, so they need to come after // any parameters of the parent. // - while(auto param = as<IRParam>(insertBeforeInst)) - insertBeforeInst = param->getNextInst(); + while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param) + insertBeforeInst = insertBeforeInst->getNextInst(); // For instructions that will be placed at module scope, // we don't care about relative ordering, but for everything @@ -6425,14 +6425,14 @@ namespace Slang // First walk through any `param` instructions, // so that we can format them nicely - if (auto firstParam = as<IRParam>(inst)) + if (auto firstParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst)) { dump(context, "(\n"); context->indent += 2; for(;;) { - auto param = as<IRParam>(inst); + auto param = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst); if (!param) break; diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index c801f156e..ef803b642 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -828,31 +828,42 @@ struct IRInst void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); }; -template<typename T> +enum class IRDynamicCastBehavior +{ + Unwrap, NoUnwrap +}; + +template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap> T* dynamicCast(IRInst* inst) { - if (inst && T::isaImpl(inst->getOp())) + if (!inst) return nullptr; + if (T::isaImpl(inst->getOp())) return static_cast<T*>(inst); + if constexpr(behavior == IRDynamicCastBehavior::Unwrap) + { + if (inst->getOp() == kIROp_AttributedType) + return dynamicCast<T>(inst->getOperand(0)); + } return nullptr; } -template<typename T> +template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap> const T* dynamicCast(const IRInst* inst) { - return dynamicCast<T>(const_cast<IRInst*>(inst)); + return dynamicCast<T, behavior>(const_cast<IRInst*>(inst)); } // `dynamic_cast` equivalent (we just use dynamicCast) -template<typename T> +template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap> T* as(IRInst* inst) { - return dynamicCast<T>(inst); + return dynamicCast<T, behavior>(inst); } -template<typename T> +template<typename T, IRDynamicCastBehavior behavior = IRDynamicCastBehavior::Unwrap> const T* as(const IRInst* inst) { - return dynamicCast<T>(inst); + return dynamicCast<T, behavior>(inst); } // `static_cast` equivalent, with debug validation @@ -1228,7 +1239,7 @@ struct IRBlock : IRInst // instructions at the start of the block. These play // the role of function parameters for the entry block // of a function, and of phi nodes in other blocks. - IRParam* getFirstParam() { return as<IRParam>(getFirstInst()); } + IRParam* getFirstParam() { return as<IRParam, IRDynamicCastBehavior::NoUnwrap>(getFirstInst()); } IRParam* getLastParam(); IRInstList<IRParam> getParams() { diff --git a/tests/autodiff/modify-vector-param.slang b/tests/autodiff/modify-vector-param.slang new file mode 100644 index 000000000..d6ddd7386 --- /dev/null +++ b/tests/autodiff/modify-vector-param.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef float.Differential dfloat; + +struct Params +{ + float vv; +} + +[Differentiable] +float3 f(float3 x, uint2 v, Params p) +{ + v.x = 1 + int(p.vv); + + x.y = x.x * x.x; + return x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + var dpa = diffPair(float3(4.0, 2.0, 3.0)); + Params p; + p.vv = 0.0; + __bwd_diff(f)(dpa, uint2(1,2), p, float3(1.0,1.0,1.0)); + + // CHECK: 9.0 + outputBuffer[0] = dpa.d.x; // Expect: 9.0 + } +} |
