diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-09-19 18:51:24 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-19 18:51:24 -0400 |
| commit | 739c3a7b53dc6489065fcd5e9f0a04370c5f9c8f (patch) | |
| tree | 593c86cbc184476479c66554cc6784b454bdec66 /source/slang | |
| parent | 359fdc9d556b4c493c588c5b8f93df85933634f8 (diff) | |
Added `[AutoPyBindCUDA]` for automatic kernel binding + `[PyExport]` for exporting type information (#3209)
* Initial: add a DiffTensor impl
* Auto-binding and diff tensor implementations now work
* Refactored diff-tensor implementation + added py-export for struct types
* Cleanup
* Update slang-ir-pytorch-cpp-binding.cpp
* Updated test names
* Update autodiff-data-flow.slang.expected
* Add more versions of load/store & default generic args for DiffTensorView.
* Add diagnostic for default generic arg and more tests
* Add more `[AutoPyBind]` tests
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 292 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 37 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 68 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 673 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 11 |
18 files changed, 1163 insertions, 33 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 956a5b29a..e989e4ffa 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2263,6 +2263,12 @@ attribute_syntax [CudaHost] : CudaHostAttribute; __attributeTarget(FuncDecl) attribute_syntax [CudaKernel] : CudaKernelAttribute; +__attributeTarget(FuncDecl) +attribute_syntax[AutoPyBindCUDA] : AutoPyBindCudaAttribute; + +__attributeTarget(AggTypeDecl) +attribute_syntax [PyExport(name: String)] : PyExportAttribute; + __attributeTarget(InterfaceDecl) attribute_syntax [COM(guid: String)] : ComInterfaceAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 423b6bfd0..495b6b989 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -264,6 +264,298 @@ extension TensorView<float> void InterlockedCompareExchange(vector<uint, N> index, float compare, float val); } +interface IDiffTensorWrapper +{ + __generic<T : __BuiltinFloatingPointType> + T load_forward(uint offset); + + __generic<T : __BuiltinFloatingPointType> + T load_forward_2(uint2 offset); + + __generic<T : __BuiltinFloatingPointType> + T load_forward_3(uint3 offset); + + __generic<T : __BuiltinFloatingPointType> + T load_forward_4(uint4 offset); + + __generic<T : __BuiltinFloatingPointType> + void load_backward(uint offset, T dOut); + + __generic<T : __BuiltinFloatingPointType> + void load_backward_2(uint2 offset, T dOut); + + __generic<T : __BuiltinFloatingPointType> + void load_backward_3(uint3 offset, T dOut); + + __generic<T : __BuiltinFloatingPointType> + void load_backward_4(uint4 offset, T dOut); + + __generic<T : __BuiltinFloatingPointType> + void store_forward(uint offset, T dx); + + __generic<T : __BuiltinFloatingPointType> + void store_forward_2(uint2 offset, T dx); + + __generic<T : __BuiltinFloatingPointType> + void store_forward_3(uint3 offset, T dx); + + __generic<T : __BuiltinFloatingPointType> + void store_forward_4(uint4 offset, T dx); + + __generic<T : __BuiltinFloatingPointType> + T store_backward(uint offset); + + __generic<T : __BuiltinFloatingPointType> + T store_backward_2(uint2 offset); + + __generic<T : __BuiltinFloatingPointType> + T store_backward_3(uint3 offset); + + __generic<T : __BuiltinFloatingPointType> + T store_backward_4(uint4 offset); +}; + +struct AtomicAdd : IDiffTensorWrapper +{ + TensorView<float> diff; + + __generic<T : __BuiltinFloatingPointType> + T load_forward(uint i) + { + return __realCast<T, float>(diff.load(i)); + } + + __generic<T : __BuiltinFloatingPointType> + T load_forward_2(uint2 i) + { + return __realCast<T, float>(diff.load(i.x, i.y)); + } + + __generic<T : __BuiltinFloatingPointType> + T load_forward_3(uint3 i) + { + return __realCast<T, float>(diff.load(i.x, i.y, i.z)); + } + + __generic<T : __BuiltinFloatingPointType> + T load_forward_4(uint4 i) + { + return __realCast<T, float>(diff.load(i.x, i.y, i.z, i.w)); + } + + __generic<T : __BuiltinFloatingPointType> + void load_backward(uint i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + void load_backward_2(uint2 i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + void load_backward_3(uint3 i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + void load_backward_4(uint4 i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + void store_forward(uint i, T dx) + { + diff.store(i, __realCast<float, T>(dx)); + } + + __generic<T : __BuiltinFloatingPointType> + void store_forward_2(uint2 i, T dx) + { + diff.store(i.x, i.y, __realCast<float, T>(dx)); + } + + __generic<T : __BuiltinFloatingPointType> + void store_forward_3(uint3 i, T dx) + { + diff.store(i.x, i.y, i.z, __realCast<float, T>(dx)); + } + + __generic<T : __BuiltinFloatingPointType> + void store_forward_4(uint4 i, T dx) + { + diff.store(i.x, i.y, i.z, i.w, __realCast<float, T>(dx)); + } + + __generic<T : __BuiltinFloatingPointType> + T store_backward(uint i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast<T, float>(oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + T store_backward_2(uint2 i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast<T, float>(oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + T store_backward_3(uint3 i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast<T, float>(oldVal); + } + + __generic<T : __BuiltinFloatingPointType> + T store_backward_4(uint4 i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast<T, float>(oldVal); + } +}; + +__generic<T: __BuiltinFloatingPointType = float, A : IDiffTensorWrapper = AtomicAdd> +struct DiffTensorView +{ + TensorView<T> primal; + A diff; + + uint size(uint i) + { + return primal.size(i); + } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint i) { return primal.load(i); } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint2 i) { return primal.load(i.x, i.y); } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint3 i) { return primal.load(i.x, i.y, i.z); } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint4 i) { return primal.load(i.x, i.y, i.z, i.w); } + + DifferentialPair<T> load_forward(uint x) + { + return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x))); + } + + DifferentialPair<T> load_forward(uint2 x) + { + return diffPair(primal.load(x.x, x.y), reinterpret<T.Differential, T>(diff.load_forward_2<T>(x))); + } + + DifferentialPair<T> load_forward(uint3 x) + { + return diffPair(primal.load(x.x, x.y, x.z), reinterpret<T.Differential, T>(diff.load_forward_3<T>(x))); + } + + DifferentialPair<T> load_forward(uint4 x) + { + return diffPair(primal.load(x.x, x.y, x.z, x.w), reinterpret<T.Differential, T>(diff.load_forward_4<T>(x))); + } + + void load_backward(uint x, T.Differential dOut) + { + diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut)); + } + + void load_backward(uint2 x, T.Differential dOut) + { + diff.load_backward_2<T>(x, reinterpret<T, T.Differential>(dOut)); + } + + void load_backward(uint3 x, T.Differential dOut) + { + diff.load_backward_3<T>(x, reinterpret<T, T.Differential>(dOut)); + } + + void load_backward(uint4 x, T.Differential dOut) + { + diff.load_backward_4<T>(x, reinterpret<T, T.Differential>(dOut)); + } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint x, T val) { primal.store(x, val); } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint2 x, T val) { primal.store(x.x, x.y, val); } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint3 x, T val) { primal.store(x.x, x.y, x.z, val); } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint4 x, T val) { primal.store(x.x, x.y, x.z, x.w, val); } + + void store_forward(uint x, DifferentialPair<T> dpval) + { + primal.store(x, dpval.p); + diff.store_forward<T>(x, reinterpret<T, T.Differential>(dpval.d)); + } + + void store_forward(uint2 x, DifferentialPair<T> dpval) + { + primal.store(x.x, x.y, dpval.p); + diff.store_forward_2<T>(x, reinterpret<T, T.Differential>(dpval.d)); + } + + void store_forward(uint3 x, DifferentialPair<T> dpval) + { + primal.store(x.x, x.y, x.z, dpval.p); + diff.store_forward_3<T>(x, reinterpret<T, T.Differential>(dpval.d)); + } + + void store_forward(uint4 x, DifferentialPair<T> dpval) + { + primal.store(x.x, x.y, x.z, x.w, dpval.p); + diff.store_forward_4<T>(x, reinterpret<T, T.Differential>(dpval.d)); + } + + void store_backward(uint x, inout DifferentialPair<T> dpval) + { + dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x))); + } + + void store_backward(uint2 x, inout DifferentialPair<T> dpval) + { + dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_2<T>(x))); + } + + void store_backward(uint3 x, inout DifferentialPair<T> dpval) + { + dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_3<T>(x))); + } + + void store_backward(uint4 x, inout DifferentialPair<T> dpval) + { + dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_4<T>(x))); + } +}; + /// Represents the handle of a Torch tensor object. __generic<T> __intrinsic_type($(kIROp_TorchTensorType)) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index d70651636..af5823db4 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1134,6 +1134,18 @@ class CudaHostAttribute : public Attribute SLANG_AST_CLASS(CudaHostAttribute) }; +class AutoPyBindCudaAttribute : public Attribute +{ + SLANG_AST_CLASS(AutoPyBindCudaAttribute) +}; + +class PyExportAttribute : public Attribute +{ + SLANG_AST_CLASS(PyExportAttribute) + + String name; +}; + class PreferRecomputeAttribute : public Attribute { SLANG_AST_CLASS(PreferRecomputeAttribute) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 53283fbe1..e8ae28c04 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -757,6 +757,19 @@ namespace Slang knownBuiltinAttr->name = name; } + else if (auto pyExportAttr = as<PyExportAttribute>(attr)) + { + // Check name string. + SLANG_ASSERT(attr->args.getCount() == 1); + + String name; + if(!checkLiteralStringVal(attr->args[0], &name)) + { + return false; + } + + pyExportAttr->name = name; + } else { if(attr->args.getCount() == 0) diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index d5d3e5a5d..5967da8b0 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -273,7 +273,8 @@ namespace Slang auto genericDeclRef = genericDeclRefType->getDeclRef(); ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); - List<Expr*> args; + List<Val*> args; + List<Val*> witnessArgs; for (Decl* member : genericDeclRef.getDecl()->members) { if (auto typeParam = as<GenericTypeParamDecl>(member)) @@ -290,7 +291,7 @@ namespace Slang // TODO: this is one place where syntax should get cloned! if (outProperType) - args.add(typeParam->initType.exp); + args.add(ExtractGenericArgVal(typeParam->initType.exp)); } else if (auto valParam = as<GenericValueParamDecl>(member)) { @@ -305,14 +306,42 @@ namespace Slang } // TODO: this is one place where syntax should get cloned! if (outProperType) - args.add(valParam->initExpr); + args.add(ExtractGenericArgVal(valParam->initExpr)); + } + else if (auto constraintParam = as<GenericTypeConstraintDecl>(member)) + { + auto genericParam = as<DeclRefType>(constraintParam->sub.type)->getDeclRef(); + if (!genericParam) + return false; + auto genericTypeParamDecl = as<GenericTypeParamDecl>(genericParam.getDecl()); + if (!genericTypeParamDecl) + return false; + auto defaultType = CheckProperType(genericTypeParamDecl->initType); + auto witness = tryGetSubtypeWitness(defaultType, CheckProperType(constraintParam->sup)); + if (!witness) + { + // diagnose + getSink()->diagnose( + genericTypeParamDecl->initType.exp, + Diagnostics::typeArgumentDoesNotConformToInterface, + defaultType, + constraintParam->sup); + + SLANG_UNEXPECTED("default type argument does not conform to interface"); + return false; + } + witnessArgs.add(witness); } else { // ignore non-parameter members } } - result = InstantiateGenericType(genericDeclRef, args); + // Combine args and witnessArgs + args.addRange(witnessArgs); + + result = DeclRefType::create(getASTBuilder(), + getASTBuilder()->getGenericAppDeclRef(genericDeclRef, args.getArrayView())); } // default case: we expect this to already be a proper type diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 5a28c8188..14f45e802 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -726,6 +726,8 @@ DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.") DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.") +DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") + // // 8xxxx - Issues specific to a particular library/technology/platform/etc. // diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index fa6fd2b43..a74459954 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2286,7 +2286,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO auto prec = getInfo(EmitOp::Postfix); needClose = maybeEmitParens(outerPrec, prec); emitOperand(inst->getOperand(0), leftSide(outerPrec, prec)); - m_writer->emit("->getBuffer()"); + if (as<IRPtrTypeBase>(inst->getOperand(0)->getDataType())) + m_writer->emit("->"); + else + m_writer->emit("."); + m_writer->emit("getBuffer()"); break; } case kIROp_MakeString: diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c77f2a6ce..98c9c9803 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -45,6 +45,7 @@ #include "slang-ir-legalize-vector-types.h" #include "slang-ir-metadata.h" #include "slang-ir-optix-entry-point-uniforms.h" +#include "slang-ir-pytorch-cpp-binding.h" #include "slang-ir-restructure.h" #include "slang-ir-restructure-scoping.h" #include "slang-ir-sccp.h" @@ -369,6 +370,9 @@ Result linkAndOptimizeIR( // being passed to saturated_cooperation fuseCallsToSaturatedCooperation(irModule); + // Generate any requested derivative wrappers + generateDerivativeWrappers(irModule, sink); + // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would // be illegal on the target platform. For example, @@ -444,9 +448,11 @@ Result linkAndOptimizeIR( { case CodeGenTarget::PyTorchCppBinding: generatePyTorchCppBinding(irModule, sink); + handleAutoBindNames(irModule); break; case CodeGenTarget::CUDASource: removeTorchKernels(irModule); + handleAutoBindNames(irModule); break; default: break; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 681c69cd3..335b6572e 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -332,12 +332,10 @@ namespace Slang as<IRFuncType>(origFunc->getFullType())); diffFunc->setFullType(diffFuncType); - if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) + if (origFunc->findDecoration<IRNameHintDecoration>()) { - auto originalName = nameHint->getName(); - StringBuilder newNameSb; - newNameSb << "s_bwd_" << originalName; - builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + auto newName = this->getTranscribedFuncName(&builder, origFunc); + builder.addNameHintDecoration(diffFunc, newName); } // Transfer checkpoint hint decorations @@ -492,6 +490,8 @@ namespace Slang builder.emitCallInst(builder.getVoidType(), propagateFunc, propagateArgs); builder.emitReturn(); + + addTranscribedFuncDecoration(builder, origFunc, cast<IRFunc>(header.differential)); return header; } diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 15d558c22..68cb4e0c9 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -94,6 +94,9 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase // Transcribe a function definition. virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0; + // Get transcribed function name from original name. + virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) = 0; + // Splits and transpose the parameter block. // After this operation, the parameter block will contain parameters for both the future // primal func and the future propagate func. @@ -160,6 +163,19 @@ struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase { return kIROp_BackwardDerivativePrimalDecoration; } + virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override + { + if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sbuilder; + sbuilder << "s_primal_ctx_" << nameHint->getName(); + return builder->getStringValue(sbuilder.getUnownedSlice()); + } + else + { + return builder->getStringValue(String("s_primal_ctx_anonymous").getUnownedSlice()); + } + } }; struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase @@ -196,6 +212,19 @@ struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase { return kIROp_BackwardDerivativePropagateDecoration; } + virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override + { + if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sbuilder; + sbuilder << "s_bwd_prop_" << nameHint->getName(); + return builder->getStringValue(sbuilder.getUnownedSlice()); + } + else + { + return builder->getStringValue(String("s_bwd_prop_anonymous").getUnownedSlice()); + } + } }; // A backward derivative function combines both primal + propagate functions and accepts no @@ -235,6 +264,19 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase { builder->addBackwardDerivativeDecoration(inst, diffFunc); } + virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override + { + if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sbuilder; + sbuilder << "s_bwd_" << nameHint->getName(); + return builder->getStringValue(sbuilder.getUnownedSlice()); + } + else + { + return builder->getStringValue(String("s_bwd_anonymous").getUnownedSlice()); + } + } }; } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4c5608341..5b90e2711 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1039,6 +1039,7 @@ void stripDerivativeDecorations(IRInst* inst) } } + void stripAutoDiffDecorationsFromChildren(IRInst* parent) { for (auto inst : parent->getChildren()) diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 6171d9a75..8f01a574f 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -295,9 +295,6 @@ public: } } - if (differentiableOutputs == 0) - sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveOutput); - DifferentiableLevel requiredDiffLevel = DifferentiableLevel::Forward; if (isBackwardDifferentiableFunc(funcInst)) requiredDiffLevel = DifferentiableLevel::Backward; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 552a8af6d..b44b7b5d9 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -719,6 +719,11 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(CudaKernelDecoration, CudaKernel, 0, 0) INST(CudaHostDecoration, CudaHost, 0, 0) INST(TorchEntryPointDecoration, TorchEntryPoint, 0, 0) + INST(AutoPyBindCudaDecoration, AutoPyBindCUDA, 0, 0) + INST(CudaKernelForwardDerivativeDecoration, CudaKernelFwdDiffRef, 0, 0) + INST(CudaKernelBackwardDerivativeDecoration, CudaKernelBwdDiffRef, 0, 0) + INST(AutoPyBindExportInfoDecoration, PyBindExportFuncInfo, 0, 0) + INST(PyExportDecoration, PyExportDecoration, 0, 0) /// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point. INST(EntryPointParamDecoration, entryPointParam, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index b88df524a..63555c08d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -460,6 +460,22 @@ struct IREntryPointDecoration : IRDecoration IR_SIMPLE_DECORATION(CudaHostDecoration) IR_SIMPLE_DECORATION(CudaKernelDecoration) +struct IRCudaKernelForwardDerivativeDecoration : IRDecoration +{ + enum { kOp = kIROp_CudaKernelForwardDerivativeDecoration }; + IR_LEAF_ISA(CudaKernelForwardDerivativeDecoration) + + IRInst* getForwardDerivativeFunc() { return getOperand(0); } +}; + +struct IRCudaKernelBackwardDerivativeDecoration : IRDecoration +{ + enum { kOp = kIROp_CudaKernelBackwardDerivativeDecoration }; + IR_LEAF_ISA(CudaKernelBackwardDerivativeDecoration) + + IRInst* getBackwardDerivativeFunc() { return getOperand(0); } +}; + struct IRGeometryInputPrimitiveTypeDecoration: IRDecoration { IR_PARENT_ISA(GeometryInputPrimitiveTypeDecoration) @@ -566,6 +582,33 @@ struct IRTorchEntryPointDecoration : IRDecoration UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } }; +struct IRAutoPyBindCudaDecoration : IRDecoration +{ + enum + { + kOp = kIROp_AutoPyBindCudaDecoration + }; + IR_LEAF_ISA(AutoPyBindCudaDecoration) + + IRStringLit* getFunctionNameOperand() { return cast<IRStringLit>(getOperand(0)); } + UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } +}; + +IR_SIMPLE_DECORATION(AutoPyBindExportInfoDecoration) + +struct IRPyExportDecoration : IRDecoration +{ + enum + { + kOp = kIROp_PyExportDecoration + }; + IR_LEAF_ISA(PyExportDecoration) + + IRStringLit* getExportNameOperand() { return cast<IRStringLit>(getOperand(0)); } + UnownedStringSlice getExportName() { return getExportNameOperand()->getStringSlice(); } +}; + + struct IRKnownBuiltinDecoration : IRDecoration { enum @@ -4368,6 +4411,16 @@ public: addDecoration(value, kIROp_TorchEntryPointDecoration, getStringValue(functionName)); } + void addAutoPyBindCudaDecoration(IRInst* value, UnownedStringSlice const& functionName) + { + addDecoration(value, kIROp_AutoPyBindCudaDecoration, getStringValue(functionName)); + } + + void addPyExportDecoration(IRInst* value, UnownedStringSlice const& exportName) + { + addDecoration(value, kIROp_PyExportDecoration, getStringValue(exportName)); + } + void addCudaDeviceExportDecoration(IRInst* value, UnownedStringSlice const& functionName) { addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName)); @@ -4383,6 +4436,21 @@ public: addDecoration(value, kIROp_CudaKernelDecoration); } + void addCudaKernelForwardDerivativeDecoration(IRInst* value, IRInst* func) + { + addDecoration(value, kIROp_CudaKernelForwardDerivativeDecoration, func); + } + + void addCudaKernelBackwardDerivativeDecoration(IRInst* value, IRInst* func) + { + addDecoration(value, kIROp_CudaKernelBackwardDerivativeDecoration, func); + } + + void addAutoPyBindExportInfoDecoration(IRInst* value) + { + addDecoration(value, kIROp_AutoPyBindExportInfoDecoration); + } + void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName) { IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) }; diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index cb1cf3db3..4b519b065 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -3659,7 +3659,14 @@ struct IRTypeLegalizationPass // * `i` is a child of `inst`. // if (legalVal.flavor == LegalVal::Flavor::simple) + { + // The resulting inst may be different from the one we added to the + // worklist, so ensure that the appropriate flags are set. + // + setHasBeenAddedOrProcessed(legalVal.irValue); + inst = legalVal.irValue; + } for( auto use = inst->firstUse; use; use = use->nextUse ) { diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index d59d57474..c0adef436 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -2,11 +2,14 @@ #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-diagnostics.h" +#include "slang-ir-autodiff.h" namespace Slang { // Convert a type to a target tuple type. -static IRType* translateToTupleType(IRBuilder& builder, IRType* type) +static IRType* translateToTupleType( + IRBuilder& builder, + IRType* type) { if (as<IRVoidType>(type)) return type; @@ -312,34 +315,490 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) inst->removeAndDeallocate(); } +IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr) +{ + if (as<IRBasicType>(type)) + return type; + + switch (type->getOp()) + { + case kIROp_TensorViewType: + return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); + + case kIROp_StructType: + { + // Create a new struct type with translated fields. + List<IRType*> fieldTypes; + for (auto field : as<IRStructType>(type)->getFields()) + { + fieldTypes.add(translateToHostType(builder, field->getFieldType())); + } + auto hostStructType = builder->createStructType(); + + // Add fields to the struct. + for (UInt i = 0; i < (UInt)fieldTypes.getCount(); i++) + { + builder->createStructField(hostStructType, builder->createStructKey(), fieldTypes[i]); + } + + return hostStructType; + } + default: + break; + } + + if (sink) + sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type); + return nullptr; +} + +IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst) +{ + if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType)) + return inst; + + switch (cudaType->getOp()) + { + case kIROp_TensorViewType: + return builder->emitMakeTensorView(cudaType, inst); + + case kIROp_StructType: + { + auto cudaStructType = cast<IRStructType>(cudaType); + auto hostStructType = cast<IRStructType>(hostType); + + List<IRStructField*> cudaFields; + for (auto field : cudaStructType->getFields()) + cudaFields.add(field); + + List<IRStructField*> hostFields; + for (auto field : hostStructType->getFields()) + hostFields.add(field); + + List<IRInst*> resultFields; + for (auto ii = 0; ii < cudaFields.getCount(); ii++) + { + auto cudaField = cudaFields[ii]; + auto hostField = hostFields[ii]; + auto cudaFieldType = cudaField->getFieldType(); + auto hostFieldType = hostField->getFieldType(); + auto castedField = castHostToCUDAType( + builder, + hostFieldType, + cudaFieldType, + builder->emitFieldExtract(hostFieldType, inst, hostField->getKey())); + + SLANG_RELEASE_ASSERT(castedField); + resultFields.add(castedField); + } + + return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer()); + } + + default: + break; + } + + // If translateToHostType worked correctly, we shouldn't get here. + SLANG_UNREACHABLE("unhandled type"); +} + +void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc) +{ + // Given a func with torch binding, we'll generate a reflection function that returns + // a tuple where the first element is another tuple of parameter names, the second + // element is a string containing the name of the fwd-diff function, and the third + // element is a string containing the name of the bwd-diff function. + // + + // Create a new function. + auto reflectionFunc = builder->createFunc(); + builder->setInsertInto(reflectionFunc); + builder->emitBlock(); + + // Go through func & generate a tuple of parameter names. + List<IRInst*> paramNames; + List<IRInst*> paramTypeNames; + UIndex paramCount = 0; + for (auto param : hostFunc->getFirstBlock()->getParams()) + { + if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) + { + paramNames.add(builder->emitGetNativeString(builder->getStringValue(nameHint->getName()))); + } + else + { + StringBuilder argNameBuilder; + argNameBuilder << "param"; + argNameBuilder << paramCount; + + paramNames.add(builder->emitGetNativeString(builder->getStringValue(argNameBuilder.getUnownedSlice()))); + } + paramCount++; + } + + for (auto param : kernelFunc->getParams()) + { + // Check for py-export decoration. + if (auto pyExportHint = param->getDataType()->findDecoration<IRPyExportDecoration>()) + { + paramTypeNames.add( + builder->emitGetNativeString( + builder->getStringValue( + pyExportHint->getExportName()))); + } + else + { + paramTypeNames.add( + builder->emitGetNativeString( + builder->getStringValue( + UnownedStringSlice("")))); + } + } + + // Create a target-tuple-type for the names + auto paramNamesTupleType = builder->getTargetTupleType( + (UInt)paramNames.getCount(), + List<IRType*>().makeRepeated(builder->getNativeStringType(), paramNames.getCount()).getBuffer()); + auto paramNamesTuple = builder->emitMakeTargetTuple(paramNamesTupleType, paramNames.getCount(), paramNames.getBuffer()); + + // Create a target-tuple-type for the type names + auto paramTypeNamesTupleType = builder->getTargetTupleType( + (UInt)paramTypeNames.getCount(), + List<IRType*>().makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()).getBuffer()); + auto paramTypeNamesTuple = builder->emitMakeTargetTuple(paramTypeNamesTupleType, paramTypeNames.getCount(), paramTypeNames.getBuffer()); + + // Find the fwd-diff function name (blank string indicates no fwd-diff) + IRInst* fwdDiffName = builder->getStringValue(UnownedStringSlice("")); + if (auto fwdDiffHint = kernelFunc->findDecoration<IRCudaKernelForwardDerivativeDecoration>()) + { + auto fwdDiffFunc = fwdDiffHint->getForwardDerivativeFunc(); + + if (auto fwdDiffFuncExternHint = fwdDiffFunc->findDecoration<IRExternCppDecoration>()) + { + fwdDiffName = builder->emitGetNativeString(builder->getStringValue(fwdDiffFuncExternHint->getName())); + } + } + + // Find the bwd-diff function name (blank string indicates no bwd-diff) + IRInst* bwdDiffName = builder->getStringValue(UnownedStringSlice("")); + if (auto bwdDiffHint = kernelFunc->findDecoration<IRCudaKernelBackwardDerivativeDecoration>()) + { + auto bwdDiffFunc = bwdDiffHint->getBackwardDerivativeFunc(); + + if (auto bwdDiffFuncExternHint = bwdDiffFunc->findDecoration<IRExternCppDecoration>()) + { + bwdDiffName = builder->emitGetNativeString(builder->getStringValue(bwdDiffFuncExternHint->getName())); + } + } + + auto stringType = builder->getNativeStringType(); + auto returnTupleType = builder->getTargetTupleType( + 4, + List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType).getBuffer()); + + // Create a target-tuple-type for the names + auto returnTupleArgs = List<IRInst*>( paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName ); + auto returnTuple = builder->emitMakeTargetTuple( + returnTupleType, + returnTupleArgs.getCount(), + returnTupleArgs.getBuffer()); + builder->emitReturn(returnTuple); + + // Set function type. + auto funcType = builder->getFuncType(List<IRType*>(), returnTupleType); + reflectionFunc->setFullType(funcType); + + // Set function name. + StringBuilder reflFuncExportName; + auto hostFuncExportName = hostFunc->findDecoration<IRExternCppDecoration>()->getName(); + reflFuncExportName << "__funcinfo__" << hostFuncExportName; + + builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); + builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice()); + builder->addPublicDecoration(reflectionFunc); + builder->addKeepAliveDecoration(reflectionFunc); +} + +IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr) +{ + auto typeMap = [&](IRType* t) -> IRType* { + if (auto tensorViewType = as<IRTensorViewType>(t)) + return builder->getTorchTensorType(tensorViewType->getElementType()); + }; + + auto type = translateToHostType(builder, param->getDataType(), sink); + if (outType) + *outType = type; + auto hostParam = builder->emitParam(type); + // Add a namehint to the param by appending the suffix "_host". + if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) + { + builder->addNameHintDecoration(hostParam, nameHint->getName()); + } + + // Then cast the param to the appropriate type. + if (auto castedParam = castHostToCUDAType(builder, type, param->getDataType(), hostParam)) + return castedParam; + + return nullptr; +} + +void markTypeForPyExport(IRType* type, DiagnosticSink* sink) +{ + // If it's a basic type, we're done. + if (as<IRBasicType>(type) || as<IRVoidType>(type)) + return; + + // If it's a struct type, mark for py-export. + if (auto structType = as<IRStructType>(type)) + { + IRBuilder builder(structType->getModule()); + + // If it already has a py-export decoration, we're done. + if (!structType->findDecoration<IRPyExportDecoration>()) + { + // Look for a name hint. + UnownedStringSlice nameHint; + if (auto nameHintDecoration = structType->findDecoration<IRNameHintDecoration>()) + nameHint = nameHintDecoration->getName(); + else + { + // If there's no name hint, we can't export this type. + SLANG_UNEXPECTED("struct marked for export has no name"); + } + + builder.addPyExportDecoration(structType, nameHint); + } + + for (auto field : structType->getFields()) + { + markTypeForPyExport(field->getFieldType(), sink); + } + return; + } +} + +void generateReflectionForType(IRType* type, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + // Emit a function that returns a py::list. + // The list will contain the names of all the fields of the type. + // + + // TODO: Fix this to avoid emitting the same type reflection multiple times. + if (!type->findDecoration<IRPyExportDecoration>()) + return; + + IRBuilder builder(type->getModule()); + + auto reflFunc = builder.createFunc(); + builder.setInsertInto(reflFunc); + builder.emitBlock(); + + List<IRInst*> fieldNames; + List<IRInst*> fieldTypeNames; + + switch (type->getOp()) + { + case kIROp_StructType: + { + for (auto field : as<IRStructType>(type)->getFields()) + { + auto structKey = field->getKey(); + // Look for a name hint. + if (auto nameHintDecoration = structKey->findDecoration<IRNameHintDecoration>()) + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(nameHintDecoration->getName()))); + else + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); + + if (!field->getFieldType()->findDecoration<IRPyExportDecoration>()) + { + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); + continue; + } + + auto fieldType = field->getFieldType(); + + fieldTypeNames.add( + builder.emitGetNativeString( + builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName()))); + } + break; + } + default: + break; + } + + auto _nameListTupleType = builder.getTargetTupleType( + (UInt)fieldNames.getCount(), + List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldNames.getCount()).getBuffer()); + auto nameListTuple = builder.emitMakeTargetTuple(_nameListTupleType, (UInt)fieldNames.getCount(), fieldNames.getBuffer()); + + auto _typeNameListTupleType = builder.getTargetTupleType( + (UInt)fieldTypeNames.getCount(), + List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()).getBuffer()); + auto typeNameListTuple = builder.emitMakeTargetTuple(_typeNameListTupleType, (UInt)fieldTypeNames.getCount(), fieldTypeNames.getBuffer()); + + auto _nameAndTypeTupleType = builder.getTargetTupleType(2, List<IRType*>(_nameListTupleType, _typeNameListTupleType).getBuffer()); + auto nameAndTypeTuple = builder.emitMakeTargetTuple( + _nameAndTypeTupleType, + 2, + List<IRInst*>(nameListTuple, typeNameListTuple).getBuffer()); + builder.emitReturn(nameAndTypeTuple); + + // Set function type. + auto funcType = builder.getFuncType(List<IRType*>(), _nameAndTypeTupleType); + reflFunc->setFullType(funcType); + + // Set function name. + StringBuilder reflFuncExportName; + reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName(); + + builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); + builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); + builder.addPublicDecoration(reflFunc); + builder.addKeepAliveDecoration(reflFunc); +} + +IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) +{ + // Check that the function has an auto-bind decoration + if (!func->findDecoration<IRAutoPyBindCudaDecoration>()) + return nullptr; + + // We will create a CudaHost function that will call func. + // But before that, we need to determine the type of CudaHost. + // + // To determine the type, first we will append two uint3 parameters to the function. + // with the names "__blockSize" and "__gridSize", these will serve as input block and + // grid size parameters for the launch. + // + // Then, we will go over the parameters of func, and find a host-mapping for each type + // by calling mapTypeToCudaHostType(IRType*), which turns structs into tuples, and + // IRTensorViewType to IRTorchTensorType. + // + // Finally, we will create a CudaHost function and transfer the name of func over to + // the generated method. + // + // The function body will first perform any conversion logic needed to convert the + // parameters from the CudaHost types to the types of func, and then use dispatch_kernel + // to dispatch func with the given block and grid size. + // + + // Create new function. + IRBuilder builder(func->getModule()); + + auto hostFunc = builder.createFunc(); + builder.setInsertInto(hostFunc); + builder.emitBlock(); + + List<IRType*> hostParamTypes; + + // Add the two uint3 parameters + auto uint3Type = builder.getVectorType(builder.getUIntType(), 3); + + auto blockSizeParam = builder.emitParam(uint3Type); + hostParamTypes.add(uint3Type); + builder.addNameHintDecoration(blockSizeParam, UnownedStringSlice("__blockSize")); + + auto gridSizeParam = builder.emitParam(uint3Type); + hostParamTypes.add(uint3Type); + builder.addNameHintDecoration(gridSizeParam, UnownedStringSlice("__gridSize")); + + List<IRInst*> mappedParams; + for (auto param : func->getFirstBlock()->getParams()) + { + IRType* hostParamType; + mappedParams.add(generateHostParamForCUDAParam(&builder, param, sink, &hostParamType)); + hostParamTypes.add(hostParamType); + markTypeForPyExport(param->getDataType(), sink); // Should we be marking the host type? + } + + // Dispatch the original function. + builder.emitDispatchKernelInst( + builder.getVoidType(), + func, + blockSizeParam, + gridSizeParam, + mappedParams.getCount(), + mappedParams.getBuffer()); + + builder.emitReturn(); + + IRFuncType* hostFuncType = builder.getFuncType(hostParamTypes, builder.getVoidType()); + hostFunc->setFullType(hostFuncType); + + // Add a torch entry point decoration to the host function to mark + // for further processing. + // + if (auto pybindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>()) + { + // Mark for further processing of torch-specific insts. + builder.addTorchEntryPointDecoration(hostFunc, pybindCudaHint->getFunctionName()); + // Mark for host-side emit logic. + builder.addCudaHostDecoration(hostFunc); + // Keep alive. This method will be accessed externally. + builder.addPublicDecoration(hostFunc); + builder.addKeepAliveDecoration(hostFunc); + } + + if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) + { + // Transfer to the host function. + builder.addExternCppDecoration(hostFunc, externCppHint->getName()); + } + + if (auto exportInfoHint = func->findDecoration<IRAutoPyBindExportInfoDecoration>()) + generateReflectionFunc(&builder, func, hostFunc); + + return hostFunc; +} + void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) { List<IRFunc*> workList; List<IRFunc*> cudaKernels; + List<IRFunc*> autoBindRequests; + List<IRType*> typesToExport; for (auto globalInst : module->getGlobalInsts()) { - auto func = as<IRFunc>(globalInst); - if (!func) - continue; - if (func->findDecoration<IRTorchEntryPointDecoration>()) - { - workList.add(func); - } - else if (func->findDecoration<IRCudaKernelDecoration>()) + if (auto func = as<IRFunc>(globalInst)) { - cudaKernels.add(func); + if (func->findDecoration<IRAutoPyBindCudaDecoration>()) + { + autoBindRequests.add(func); + } + if (func->findDecoration<IRTorchEntryPointDecoration>()) + { + workList.add(func); + } + else if (func->findDecoration<IRCudaKernelDecoration>()) + { + cudaKernels.add(func); + } + else + { + // Remove all other export decorations if this is not a cuda host func. + if (auto decor = func->findDecoration<IRPublicDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRHLSLExportDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRKeepAliveDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRDllExportDecoration>()) + decor->removeAndDeallocate(); + } } - else + } + + // Generate CUDA wrappers for all functions that have the auto-bind decoration. + for (auto func : autoBindRequests) + { + if (auto hostFunc = generateCUDAWrapperForFunc(func, sink)) { - // Remove all other export decorations if this is not a cuda host func. - if (auto decor = func->findDecoration<IRPublicDecoration>()) - decor->removeAndDeallocate(); - if (auto decor = func->findDecoration<IRHLSLExportDecoration>()) - decor->removeAndDeallocate(); - if (auto decor = func->findDecoration<IRKeepAliveDecoration>()) - decor->removeAndDeallocate(); - if (auto decor = func->findDecoration<IRDllExportDecoration>()) - decor->removeAndDeallocate(); + // Add generated wrapper to worklist for python bindings. + workList.add(hostFunc); } } @@ -355,6 +814,20 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) block = nextBlock; } } + + for (auto globalInst : module->getGlobalInsts()) + { + if (auto type = as<IRType>(globalInst)) + { + if (type->findDecoration<IRPyExportDecoration>()) + { + typesToExport.add(type); + } + } + } + + for (auto type : typesToExport) + generateReflectionForType(type, sink); } // Remove all [TorchEntryPoint] functions when emitting CUDA source. @@ -372,4 +845,164 @@ void removeTorchKernels(IRModule* module) inst->removeAndDeallocate(); } +void handleAutoBindNames(IRModule* module) +{ + // We need to rewrite extern-cpp names for functions that have an auto-bind decoration. + // since the name needs to be used for the host function. + // + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>()) + { + // Find an extern decoration on the original function, and append a prefix to the name. + if (auto externCppHint = globalInst->findDecoration<IRExternCppDecoration>()) + { + IRBuilder builder(module); + + // Change the name of the original function. + StringBuilder nameBuilder; + nameBuilder << "__kernel__" << externCppHint->getName(); + externCppHint->removeAndDeallocate(); + builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); + } + } + } +} + +void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + for (auto globalInst : module->getGlobalInsts()) + { + if (!as<IRFunc>(globalInst)) + continue; + + // Look for methods marked with auto-bind and are differentiable. + if (globalInst->findDecoration<IRAutoPyBindCudaDecoration>()) + { + if(globalInst->findDecoration<IRForwardDifferentiableDecoration>() || + globalInst->findDecoration<IRBackwardDifferentiableDecoration>()) + { + // We'll generate a wrapper for this method that calls fwd_diff(fn) + // but an important thing to note is that we won't actually employ the usual + // differentiable typing rules. We'll assume none of the parameters are + // differentiable & throw a warning if some are. This is because, for the auto-binding + // scenario, we expect to only see tensor types, and their differentiation is handled using + // tensor _pair_ types which handle the differentiable loads/stores through custom derivatives + // + // For now, the user is expected to explicitly use the tensor pair types, so we will simply copy over + // the original function's signature. + // In the future, when we update the type system to be able to specify the corresponding pair type, + // we can update this logic. + // + + // Create a new wrapper function. + IRBuilder builder(module); + auto func = cast<IRFunc>(globalInst); + auto wrapperFunc = builder.createFunc(); + builder.setInsertInto(wrapperFunc); + builder.emitBlock(); + + // Clone the parameter list. + List<IRInst*> params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(builder.emitParam(param->getFullType())); + } + + wrapperFunc->setFullType(func->getFullType()); + + auto fwdDiffFunc = builder.emitForwardDifferentiateInst(func->getFullType(), func); + auto fwdDiffCall = builder.emitCallInst( + func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer()); + + builder.emitReturn(fwdDiffCall); + + // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. + if (auto kernelHint = func->findDecoration<IRCudaKernelDecoration>()) + builder.addCudaKernelDecoration(wrapperFunc); + + // Add an auto-pybind-cuda decoration to the wrapper function to further generate the + // host-side binding for the derivative kernel. + // + { + auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>(); + StringBuilder nameBuilder; + nameBuilder << autoPyBindCudaHint->getFunctionName() << "_fwd_diff"; + builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + // Build a name for the wrapper function: <original_name>_fwd_diff + if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) + { + StringBuilder nameBuilder; + nameBuilder << externCppHint->getName() << "_fwd_diff"; + builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + builder.addPublicDecoration(wrapperFunc); + builder.addKeepAliveDecoration(wrapperFunc); + + builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc); + } + + if (globalInst->findDecoration<IRBackwardDifferentiableDecoration>()) + { + // The reasoning for the reverse-mode is the same as the forward-mode version + // (see above) + // + + // Create a new wrapper function. + IRBuilder builder(module); + auto func = cast<IRFunc>(globalInst); + auto wrapperFunc = builder.createFunc(); + builder.setInsertInto(wrapperFunc); + builder.emitBlock(); + + // Clone the parameter list. + List<IRInst*> params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(builder.emitParam(param->getFullType())); + } + + wrapperFunc->setFullType(func->getFullType()); + + auto fwdDiffFunc = builder.emitBackwardDifferentiateInst(func->getFullType(), func); + auto fwdDiffCall = builder.emitCallInst( + func->getResultType(), fwdDiffFunc, params.getCount(), params.getBuffer()); + + builder.emitReturn(fwdDiffCall); + + // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well. + if (auto kernelHint = func->findDecoration<IRCudaKernelDecoration>()) + builder.addCudaKernelDecoration(wrapperFunc); + + // Add an auto-pybind-cuda decoration to the wrapper function to further generate the + // host-side binding for the derivative kernel. + // + { + auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>(); + StringBuilder nameBuilder; + nameBuilder << autoPyBindCudaHint->getFunctionName() << "_bwd_diff"; + builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + // Build a name for the wrapper function: <original_name>_bwd_diff + if (auto externCppHint = func->findDecoration<IRExternCppDecoration>()) + { + StringBuilder nameBuilder; + nameBuilder << externCppHint->getName() << "_bwd_diff"; + builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice()); + } + + builder.addPublicDecoration(wrapperFunc); + builder.addKeepAliveDecoration(wrapperFunc); + + builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc); + } + } + } +} + } diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h index c35b6a8eb..dd7dcc9a4 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.h +++ b/source/slang/slang-ir-pytorch-cpp-binding.h @@ -7,6 +7,8 @@ class DiagnosticSink; void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink); void removeTorchKernels(IRModule* module); +void handleAutoBindNames(IRModule* module); +void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8d1a86d4d..4f64087d8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1353,6 +1353,17 @@ static void addLinkageDecoration( builder->addPublicDecoration(inst); builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); } + else if (as<AutoPyBindCudaAttribute>(modifier)) + { + builder->addAutoPyBindCudaDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addAutoPyBindExportInfoDecoration(inst); + } + else if (auto pyExportModifier = as<PyExportAttribute>(modifier)) + { + builder->addPyExportDecoration(inst, pyExportModifier->name.getLength() + ? pyExportModifier->name.getUnownedSlice() + : decl->getName()->text.getUnownedSlice()); + } else if (as<KnownBuiltinAttribute>(modifier)) { // We add this to the internal instruction, like other name-like |
