From 5b2eb06816521cc0fcfe03258452560bd200002d Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 21 Sep 2023 14:00:48 -0700 Subject: Various slangpy fixes. (#3227) * Make dynamic cast transparent through `IRAttributedType`. * Add [CUDAXxx] variant of attributes. * Support marshaling of vector types. * Wrap cuda kernels in `extern "C"` block. --------- Co-authored-by: Yong He --- source/slang/core.meta.slang | 9 ++++ source/slang/slang-emit-c-like.cpp | 27 ++++++++++++ source/slang/slang-ir-address-analysis.cpp | 2 +- source/slang/slang-ir-autodiff-fwd.cpp | 4 +- source/slang/slang-ir-autodiff-unzip.cpp | 4 +- source/slang/slang-ir-autodiff.cpp | 2 +- source/slang/slang-ir-check-differentiability.cpp | 2 +- source/slang/slang-ir-constexpr.cpp | 2 +- source/slang/slang-ir-inst-defs.h | 3 ++ source/slang/slang-ir-insts.h | 14 +++++++ source/slang/slang-ir-link.cpp | 4 +- source/slang/slang-ir-peephole.cpp | 2 +- source/slang/slang-ir-pytorch-cpp-binding.cpp | 51 +++++++++++++++++++---- source/slang/slang-ir-sccp.cpp | 2 +- source/slang/slang-ir-simplify-cfg.cpp | 2 +- source/slang/slang-ir-ssa-register-allocate.cpp | 2 +- source/slang/slang-ir-util.cpp | 6 +-- source/slang/slang-ir-validate.cpp | 2 +- source/slang/slang-ir.cpp | 14 +++---- source/slang/slang-ir.h | 29 +++++++++---- 20 files changed, 142 insertions(+), 41 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 55cf5896f..43640eb41 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2289,6 +2289,15 @@ attribute_syntax [CudaHost] : CudaHostAttribute; __attributeTarget(FuncDecl) attribute_syntax [CudaKernel] : CudaKernelAttribute; +__attributeTarget(FuncDecl) +attribute_syntax [CUDADeviceExport] : CudaDeviceExportAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [CUDAHost] : CudaHostAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [CUDAKernel] : CudaKernelAttribute; + __attributeTarget(FuncDecl) attribute_syntax[AutoPyBindCUDA] : AutoPyBindCudaAttribute; diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index a74459954..d26893987 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3321,6 +3321,20 @@ bool CLikeSourceEmitter::isTargetIntrinsic(IRInst* inst) return findTargetIntrinsicDefinition(inst, intrinsicDef); } +bool shouldWrappInExternCBlock(IRFunc* func) +{ + for (auto decor : func->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ExternCDecoration: + case kIROp_CudaKernelDecoration: + return true; + } + } + return false; +} + void CLikeSourceEmitter::emitFunc(IRFunc* func) { // Target-intrinsic functions should never be emitted @@ -3329,6 +3343,15 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func) if (isTargetIntrinsic(func)) return; + bool shouldCloseExternCBlock = shouldWrappInExternCBlock(func); + if (shouldCloseExternCBlock) + { + // If this is a C++ `extern "C"` function, then we need to emit + // it as a C function, since that is what the C++ compiler will + // expect. + // + m_writer->emit("extern \"C\" {\n"); + } if(!isDefinition(func)) { @@ -3345,6 +3368,10 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func) // emitSimpleFunc(func); } + if (shouldCloseExternCBlock) + { + m_writer->emit("}\n"); + } } void CLikeSourceEmitter::emitFuncDecorationsImpl(IRFunc* func) diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp index b00b713b5..178e15292 100644 --- a/source/slang/slang-ir-address-analysis.cpp +++ b/source/slang/slang-ir-address-analysis.cpp @@ -64,7 +64,7 @@ namespace Slang } } - if (!latestOperand || as(latestOperand)) + if (!latestOperand || as(latestOperand)) inst->insertBefore(earliestBlock->getFirstOrdinaryInst()); else inst->insertAfter(latestOperand); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c906f93eb..10c8cdc51 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1647,7 +1647,7 @@ bool isLocalPointer(IRInst* ptrInst) // referencing something outside the function scope. // auto addr = getRootAddr(ptrInst); - return as(addr) || as(addr); + return as(addr) || as(addr); } void lowerSwizzledStores(IRModule* module, IRFunc* func) @@ -2067,7 +2067,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam } auto primalInst = cloneInst(&cloneEnv, builder, origParam); - if (auto primalParam = as(primalInst)) + if (auto primalParam = as(primalInst)) { SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); primalParam->removeFromParent(); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index cb743c06a..59653c4ae 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -228,7 +228,7 @@ struct ExtractPrimalFuncContext } else { - if (as(inst)) + if (as(inst)) builder.setInsertBefore(block->getFirstOrdinaryInst()); else builder.setInsertAfter(inst); @@ -427,7 +427,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( for (auto inst : instsToRemove) { - if (as(inst)) + if (as(inst)) removePhiArgs(inst); inst->removeAndDeallocate(); } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 5b90e2711..a6437e3e4 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -324,7 +324,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( result = originalPairType; return result; } - if (as(primalType)) + if (as(primalType)) { result = nullptr; return result; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8124f9210..b937fe052 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -471,7 +471,7 @@ public: auto block = as(inst->getParent()); if (block != funcInst->getFirstBlock()) { - auto paramIndex = getParamIndexInBlock(as(inst)); + auto paramIndex = getParamIndexInBlock(as(inst)); if (paramIndex != -1) { for (auto p : block->getPredecessors()) diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 6e56ebd96..34b56bfef 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -172,7 +172,7 @@ IRLoop* isLoopPhi(IRParam* param) bool opCanBeConstExprByBackwardPass(IRInst* value) { if (value->getOp() == kIROp_Param) - return isLoopPhi(as(value)); + return isLoopPhi(as(value)); return opCanBeConstExpr(value->getOp()); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 6b5d8e59a..a087e59d7 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -766,6 +766,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// An extern_cpp decoration marks the inst to emit its name without mangling for C++ interop. INST(ExternCppDecoration, externCpp, 1, 0) + // An externC decoration marks a function should be emitted inside an extern "C" block. + INST(ExternCDecoration, externC, 0, 0) + /// An dllImport decoration marks a function as imported from a DLL. Slang will generate dynamic function loading logic to use this function at runtime. INST(DllImportDecoration, dllImport, 2, 0) /// An dllExport decoration marks a function as an export symbol. Slang will generate a native wrapper function that is exported to DLL. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 0fd48f546..04f993838 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -543,6 +543,15 @@ struct IRExternCppDecoration : IRDecoration UnownedStringSlice getName() { return getNameOperand()->getStringSlice(); } }; +struct IRExternCDecoration : IRDecoration +{ + enum + { + kOp = kIROp_ExternCDecoration + }; + IR_LEAF_ISA(ExternCDecoration) +}; + struct IRDllImportDecoration : IRDecoration { enum @@ -4276,6 +4285,11 @@ public: addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName)); } + void addExternCDecoration(IRInst* value) + { + addDecoration(value, kIROp_ExternCDecoration); + } + void addForceInlineDecoration(IRInst* value) { addDecoration(value, kIROp_ForceInlineDecoration); diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index e37aa322e..b8f43c5f2 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -589,7 +589,9 @@ IRGeneric* cloneGenericImpl( auto originalParam = originalGeneric->getFirstParam(); ShortList> paramMapping; - for (; clonedParam && originalParam; (clonedParam = as(clonedParam->next)), (originalParam = as(originalParam->next))) + for (; clonedParam && originalParam; + (clonedParam = as(clonedParam->next)), + (originalParam = as(originalParam->next))) { paramMapping.add(KeyValuePair(clonedParam, originalParam)); } diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index f6e8a3458..d344590b1 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -776,7 +776,7 @@ struct PeepholeContext : InstPassBase break; UInt paramIndex = 0; auto prevParam = inst->getPrevInst(); - while (as(prevParam)) + while (as(prevParam)) { prevParam = prevParam->getPrevInst(); paramIndex++; diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index c723902de..41665ddf7 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -193,7 +193,7 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst IRIntegerValue i = 0; for (auto field : structType->getFields()) { - auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement(translateToTupleType(builder, field->getFieldType()), val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); if (!convertedElement) return nullptr; @@ -315,23 +315,38 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) inst->removeAndDeallocate(); } -IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr) +IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) { - if (as(type)) + if (as(type) || as(type)) return type; switch (type->getOp()) { case kIROp_TensorViewType: return builder->getTorchTensorType(as(type)->getElementType()); - +#if 0 + case kIROp_VectorType: + { + // Create a new struct type representing the vector. + auto hostStructType = builder->createStructType(); + const char* names[4] = { "x", "y", "z", "w" }; + for (IRIntegerValue i = 0; i < getIntVal(as(type)->getElementCount()); i++) + { + auto key = builder->createStructKey(); + if (i < 4) + builder->addNameHintDecoration(key, UnownedStringSlice(names[i])); + builder->createStructField(hostStructType, key, as(type)->getElementType()); + } + return hostStructType; + } +#endif case kIROp_StructType: { // Create a new struct type with translated fields. List fieldTypes; for (auto field : as(type)->getFields()) { - fieldTypes.add(translateToHostType(builder, field->getFieldType())); + fieldTypes.add(translateToHostType(builder, field->getFieldType(), func)); } auto hostStructType = builder->createStructType(); @@ -348,12 +363,14 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* si } if (sink) - sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type); + sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type, func); return nullptr; } IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst) { + if (hostType == cudaType) + return inst; if (as(hostType) && as(cudaType)) return inst; @@ -361,7 +378,18 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp { case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); - +#if 0 + case kIROp_VectorType: + { + List args; + auto hostStructType = cast(hostType); + for (auto field : hostStructType->getFields()) + { + args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey())); + } + return builder->emitMakeVector(cudaType, args); + } +#endif case kIROp_StructType: { auto cudaStructType = cast(cudaType); @@ -522,7 +550,7 @@ void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* host IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr) { - auto type = translateToHostType(builder, param->getDataType(), sink); + auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink); if (outType) *outType = type; auto hostParam = builder->emitParam(type); @@ -859,6 +887,7 @@ void handleAutoBindNames(IRModule* module) nameBuilder << "__kernel__" << externCppHint->getName(); externCppHint->removeAndDeallocate(); builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); + builder.addExternCDecoration(globalInst); } } } @@ -915,7 +944,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. if (func->findDecoration()) + { builder.addCudaKernelDecoration(wrapperFunc); + builder.addExternCDecoration(wrapperFunc); + } // Add an auto-pybind-cuda decoration to the wrapper function to further generate the // host-side binding for the derivative kernel. @@ -971,7 +1003,10 @@ void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. if (func->findDecoration()) + { builder.addCudaKernelDecoration(wrapperFunc); + builder.addExternCDecoration(wrapperFunc); + } // Add an auto-pybind-cuda decoration to the wrapper function to further generate the // host-side binding for the derivative kernel. diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index 5ae858256..d874514fe 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -1022,7 +1022,7 @@ struct SCCPContext // that provide arguments. We will see that logic shortly, when // handling `IRUnconditionalBranch`. // - if(as(inst)) + if(as(inst)) return; // We want to special-case terminator instructions here, diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 44a8909e4..e00c24bdc 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -205,7 +205,7 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) auto rootAddr = getRootAddr(addr); if (isGlobalOrUnknownMutableAddress(func, rootAddr)) return true; - if (as(rootAddr)) + if (as(rootAddr)) return true; // If we can't find the address from our map, we conservatively assume it is an unknown address. diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp index a93a3a8f4..ce502b454 100644 --- a/source/slang/slang-ir-ssa-register-allocate.cpp +++ b/source/slang/slang-ir-ssa-register-allocate.cpp @@ -103,7 +103,7 @@ struct RegisterAllocateContext bool isUseOfParamAfterPhiAssignment(IRDominatorTree* dom, IRUse* useToTest, IRInst* phiParam, IRInst* phiArg) { - IRParam* param = as(phiParam); + IRParam* param = as(phiParam); if (!param) return false; IRUse* branchUse = nullptr; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 5ead1a1f4..4a92d263e 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -617,7 +617,7 @@ void removeLinkageDecorations(IRGlobalValueWithCode* func) void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) { - if (as(inst)) + if (as(inst)) { SLANG_RELEASE_ASSERT(as(inst->getParent())); auto lastParam = as(inst->getParent())->getLastParam(); @@ -631,7 +631,7 @@ void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) { - if (as(inst)) + if (as(inst)) { SLANG_RELEASE_ASSERT(as(inst->getParent())); auto lastParam = as(inst->getParent())->getLastParam(); @@ -818,7 +818,7 @@ void moveParams(IRBlock* dest, IRBlock* src) for (auto param = src->getFirstChild(); param;) { auto nextInst = param->getNextInst(); - if (as(param) || as(param)) + if (as(param) || as(param)) { param->insertAtEnd(dest); } diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 2973e1ee5..af793bb54 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -80,7 +80,7 @@ namespace Slang validate(context, state <= kState_AfterDecoration, child, "decorations must come before other child instructions"); state = kState_AfterDecoration; } - else if( as(child) ) + else if( as(child) ) { validate(context, state <= kState_AfterParam, child, "parameters must come before ordinary instructions"); state = kState_AfterParam; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 95d6c9e1f..ff6c8c39e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -304,12 +304,12 @@ namespace Slang IRParam* IRParam::getNextParam() { - return as(getNextInst()); + return as(getNextInst()); } IRParam* IRParam::getPrevParam() { - return as(getPrevInst()); + return as(getPrevInst()); } // IRArrayTypeBase @@ -472,7 +472,7 @@ namespace Slang // If the last instruction is a parameter, then // there are no ordinary instructions, so the last // one is a null pointer. - if (as(inst)) + if (as(inst)) return nullptr; // Otherwise the last instruction is the last "ordinary" @@ -1643,8 +1643,8 @@ namespace Slang // instructions, so they need to come after // any parameters of the parent. // - while(auto param = as(insertBeforeInst)) - insertBeforeInst = param->getNextInst(); + while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param) + insertBeforeInst = insertBeforeInst->getNextInst(); // For instructions that will be placed at module scope, // we don't care about relative ordering, but for everything @@ -6425,14 +6425,14 @@ namespace Slang // First walk through any `param` instructions, // so that we can format them nicely - if (auto firstParam = as(inst)) + if (auto firstParam = as(inst)) { dump(context, "(\n"); context->indent += 2; for(;;) { - auto param = as(inst); + auto param = as(inst); if (!param) break; diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index c801f156e..ef803b642 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -828,31 +828,42 @@ struct IRInst void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); }; -template +enum class IRDynamicCastBehavior +{ + Unwrap, NoUnwrap +}; + +template T* dynamicCast(IRInst* inst) { - if (inst && T::isaImpl(inst->getOp())) + if (!inst) return nullptr; + if (T::isaImpl(inst->getOp())) return static_cast(inst); + if constexpr(behavior == IRDynamicCastBehavior::Unwrap) + { + if (inst->getOp() == kIROp_AttributedType) + return dynamicCast(inst->getOperand(0)); + } return nullptr; } -template +template const T* dynamicCast(const IRInst* inst) { - return dynamicCast(const_cast(inst)); + return dynamicCast(const_cast(inst)); } // `dynamic_cast` equivalent (we just use dynamicCast) -template +template T* as(IRInst* inst) { - return dynamicCast(inst); + return dynamicCast(inst); } -template +template const T* as(const IRInst* inst) { - return dynamicCast(inst); + return dynamicCast(inst); } // `static_cast` equivalent, with debug validation @@ -1228,7 +1239,7 @@ struct IRBlock : IRInst // instructions at the start of the block. These play // the role of function parameters for the entry block // of a function, and of phi nodes in other blocks. - IRParam* getFirstParam() { return as(getFirstInst()); } + IRParam* getFirstParam() { return as(getFirstInst()); } IRParam* getLastParam(); IRInstList getParams() { -- cgit v1.2.3