summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang6
-rw-r--r--source/slang/diff.meta.slang292
-rw-r--r--source/slang/slang-ast-modifier.h12
-rw-r--r--source/slang/slang-check-modifier.cpp13
-rw-r--r--source/slang/slang-check-type.cpp37
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-c-like.cpp6
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-rev.h42
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp3
-rw-r--r--source/slang/slang-ir-inst-defs.h5
-rw-r--r--source/slang/slang-ir-insts.h68
-rw-r--r--source/slang/slang-ir-legalize-types.cpp7
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp673
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp11
18 files changed, 1163 insertions, 33 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 956a5b29a..e989e4ffa 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2263,6 +2263,12 @@ attribute_syntax [CudaHost] : CudaHostAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [CudaKernel] : CudaKernelAttribute;
+__attributeTarget(FuncDecl)
+attribute_syntax[AutoPyBindCUDA] : AutoPyBindCudaAttribute; // [AutoPyBindCUDA] on a function declaration; represented by AutoPyBindCudaAttribute.
+
+__attributeTarget(AggTypeDecl)
+attribute_syntax [PyExport(name: String)] : PyExportAttribute; // [PyExport("...")] on an aggregate type; takes one string argument `name`.
+
__attributeTarget(InterfaceDecl)
attribute_syntax [COM(guid: String)] : ComInterfaceAttribute;
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 423b6bfd0..495b6b989 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -264,6 +264,298 @@ extension TensorView<float>
void InterlockedCompareExchange(vector<uint, N> index, float compare, float val);
}
+interface IDiffTensorWrapper // Operations DiffTensorView requires of its `diff` member; every method is generic over a floating-point T.
+{ // Unsuffixed methods take a scalar uint offset; the _2/_3/_4 variants take uint2/uint3/uint4.
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward(uint offset); // load_forward*: value DiffTensorView.load_forward pairs with the primal load as its differential.
+
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_2(uint2 offset);
+
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_3(uint3 offset);
+
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_4(uint4 offset);
+ // load_backward*: receive dOut, the derivative DiffTensorView.load_backward passes for a load at `offset`.
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward(uint offset, T dOut);
+
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_2(uint2 offset, T dOut);
+
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_3(uint3 offset, T dOut);
+
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_4(uint4 offset, T dOut);
+ // store_forward*: receive dx, the differential DiffTensorView.store_forward passes for a store at `offset`.
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward(uint offset, T dx);
+
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_2(uint2 offset, T dx);
+
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_3(uint3 offset, T dx);
+
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_4(uint4 offset, T dx);
+ // store_backward*: return the value DiffTensorView.store_backward installs as the differential of the stored value.
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward(uint offset);
+
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_2(uint2 offset);
+
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_3(uint3 offset);
+
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_4(uint4 offset);
+};
+
+struct AtomicAdd : IDiffTensorWrapper // IDiffTensorWrapper implementation backed by a TensorView<float>.
+{
+ TensorView<float> diff; // Differential storage; values are converted between float and T with __realCast on every access.
+ // load_forward*: read the stored differential at index i and return it as T.
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward(uint i)
+ {
+ return __realCast<T, float>(diff.load(i));
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_2(uint2 i)
+ {
+ return __realCast<T, float>(diff.load(i.x, i.y));
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_3(uint3 i)
+ {
+ return __realCast<T, float>(diff.load(i.x, i.y, i.z));
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_4(uint4 i)
+ {
+ return __realCast<T, float>(diff.load(i.x, i.y, i.z, i.w));
+ }
+ // load_backward*: accumulate dOut (converted to float) into the stored differential via InterlockedAdd.
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward(uint i, T dOut)
+ {
+ float oldVal; // Third argument of InterlockedAdd; not read afterwards.
+ diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_2(uint2 i, T dOut)
+ {
+ float oldVal;
+ diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal); // The whole uint2 index is passed, unlike load/store which take components.
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_3(uint3 i, T dOut)
+ {
+ float oldVal;
+ diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_4(uint4 i, T dOut)
+ {
+ float oldVal;
+ diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ }
+ // store_forward*: write dx (converted to float) into the stored differential at index i.
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward(uint i, T dx)
+ {
+ diff.store(i, __realCast<float, T>(dx));
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_2(uint2 i, T dx)
+ {
+ diff.store(i.x, i.y, __realCast<float, T>(dx));
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_3(uint3 i, T dx)
+ {
+ diff.store(i.x, i.y, i.z, __realCast<float, T>(dx));
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_4(uint4 i, T dx)
+ {
+ diff.store(i.x, i.y, i.z, i.w, __realCast<float, T>(dx));
+ }
+ // store_backward*: InterlockedExchange the slot with 0 and return what the exchange wrote to oldVal (presumably the previous contents - confirm against TensorView).
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward(uint i)
+ {
+ float oldVal;
+ diff.InterlockedExchange(i, (float)0, oldVal);
+ return __realCast<T, float>(oldVal);
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_2(uint2 i)
+ {
+ float oldVal;
+ diff.InterlockedExchange(i, (float)0, oldVal);
+ return __realCast<T, float>(oldVal);
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_3(uint3 i)
+ {
+ float oldVal;
+ diff.InterlockedExchange(i, (float)0, oldVal);
+ return __realCast<T, float>(oldVal);
+ }
+
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_4(uint4 i)
+ {
+ float oldVal;
+ diff.InterlockedExchange(i, (float)0, oldVal);
+ return __realCast<T, float>(oldVal);
+ }
+};
+
+__generic<T: __BuiltinFloatingPointType = float, A : IDiffTensorWrapper = AtomicAdd>
+struct DiffTensorView // Pairs a primal TensorView<T> with differential storage A and attaches custom derivatives to load/store.
+{
+ TensorView<T> primal; // Primal values.
+ A diff; // Differential storage; only accessed through IDiffTensorWrapper methods.
+ // size: forwards to primal.size(i).
+ uint size(uint i)
+ {
+ return primal.size(i);
+ }
+ // load: reads from primal only; its derivatives are the load_forward / load_backward overloads named in the attributes.
+ [BackwardDerivative(load_backward)]
+ [ForwardDerivative(load_forward)]
+ T load(uint i) { return primal.load(i); }
+
+ [BackwardDerivative(load_backward)]
+ [ForwardDerivative(load_forward)]
+ T load(uint2 i) { return primal.load(i.x, i.y); }
+
+ [BackwardDerivative(load_backward)]
+ [ForwardDerivative(load_forward)]
+ T load(uint3 i) { return primal.load(i.x, i.y, i.z); }
+
+ [BackwardDerivative(load_backward)]
+ [ForwardDerivative(load_forward)]
+ T load(uint4 i) { return primal.load(i.x, i.y, i.z, i.w); }
+ // load_forward: pair the primal value with the differential fetched from diff, reinterpreted from T to T.Differential.
+ DifferentialPair<T> load_forward(uint x)
+ {
+ return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x)));
+ }
+
+ DifferentialPair<T> load_forward(uint2 x)
+ {
+ return diffPair(primal.load(x.x, x.y), reinterpret<T.Differential, T>(diff.load_forward_2<T>(x)));
+ }
+
+ DifferentialPair<T> load_forward(uint3 x)
+ {
+ return diffPair(primal.load(x.x, x.y, x.z), reinterpret<T.Differential, T>(diff.load_forward_3<T>(x)));
+ }
+
+ DifferentialPair<T> load_forward(uint4 x)
+ {
+ return diffPair(primal.load(x.x, x.y, x.z, x.w), reinterpret<T.Differential, T>(diff.load_forward_4<T>(x)));
+ }
+ // load_backward: hand dOut (reinterpreted from T.Differential to T) to diff; primal is not touched.
+ void load_backward(uint x, T.Differential dOut)
+ {
+ diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
+ }
+
+ void load_backward(uint2 x, T.Differential dOut)
+ {
+ diff.load_backward_2<T>(x, reinterpret<T, T.Differential>(dOut));
+ }
+
+ void load_backward(uint3 x, T.Differential dOut)
+ {
+ diff.load_backward_3<T>(x, reinterpret<T, T.Differential>(dOut));
+ }
+
+ void load_backward(uint4 x, T.Differential dOut)
+ {
+ diff.load_backward_4<T>(x, reinterpret<T, T.Differential>(dOut));
+ }
+ // store: writes to primal only; its derivatives are the store_forward / store_backward overloads named in the attributes.
+ [BackwardDerivative(store_backward)]
+ [ForwardDerivative(store_forward)]
+ void store(uint x, T val) { primal.store(x, val); }
+
+ [BackwardDerivative(store_backward)]
+ [ForwardDerivative(store_forward)]
+ void store(uint2 x, T val) { primal.store(x.x, x.y, val); }
+
+ [BackwardDerivative(store_backward)]
+ [ForwardDerivative(store_forward)]
+ void store(uint3 x, T val) { primal.store(x.x, x.y, x.z, val); }
+
+ [BackwardDerivative(store_backward)]
+ [ForwardDerivative(store_forward)]
+ void store(uint4 x, T val) { primal.store(x.x, x.y, x.z, x.w, val); }
+ // store_forward: store dpval.p to primal and pass dpval.d (reinterpreted as T) to diff.
+ void store_forward(uint x, DifferentialPair<T> dpval)
+ {
+ primal.store(x, dpval.p);
+ diff.store_forward<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ }
+
+ void store_forward(uint2 x, DifferentialPair<T> dpval)
+ {
+ primal.store(x.x, x.y, dpval.p);
+ diff.store_forward_2<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ }
+
+ void store_forward(uint3 x, DifferentialPair<T> dpval)
+ {
+ primal.store(x.x, x.y, x.z, dpval.p);
+ diff.store_forward_3<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ }
+
+ void store_forward(uint4 x, DifferentialPair<T> dpval)
+ {
+ primal.store(x.x, x.y, x.z, x.w, dpval.p);
+ diff.store_forward_4<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ }
+ // store_backward: keep dpval.p and replace the differential with the value returned by diff (the old dpval.d is discarded).
+ void store_backward(uint x, inout DifferentialPair<T> dpval)
+ {
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
+ }
+
+ void store_backward(uint2 x, inout DifferentialPair<T> dpval)
+ {
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_2<T>(x)));
+ }
+
+ void store_backward(uint3 x, inout DifferentialPair<T> dpval)
+ {
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_3<T>(x)));
+ }
+
+ void store_backward(uint4 x, inout DifferentialPair<T> dpval)
+ {
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_4<T>(x)));
+ }
+};
+
/// Represents the handle of a Torch tensor object.
__generic<T>
__intrinsic_type($(kIROp_TorchTensorType))
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index d70651636..af5823db4 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1134,6 +1134,18 @@ class CudaHostAttribute : public Attribute
SLANG_AST_CLASS(CudaHostAttribute)
};
+class AutoPyBindCudaAttribute : public Attribute // AST node for the argument-less [AutoPyBindCUDA] attribute declared in core.meta.slang.
+{
+ SLANG_AST_CLASS(AutoPyBindCudaAttribute)
+};
+
+class PyExportAttribute : public Attribute // AST node for the [PyExport(name)] attribute declared in core.meta.slang.
+{
+ SLANG_AST_CLASS(PyExportAttribute)
+
+ String name; // Set from the attribute's single string-literal argument during modifier checking.
+};
+
class PreferRecomputeAttribute : public Attribute
{
SLANG_AST_CLASS(PreferRecomputeAttribute)
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 53283fbe1..e8ae28c04 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -757,6 +757,19 @@ namespace Slang
knownBuiltinAttr->name = name;
}
+ else if (auto pyExportAttr = as<PyExportAttribute>(attr))
+ {
+ // Check name string.
+ SLANG_ASSERT(attr->args.getCount() == 1);
+
+ String name;
+ if(!checkLiteralStringVal(attr->args[0], &name))
+ {
+ return false;
+ }
+
+ pyExportAttr->name = name;
+ }
else
{
if(attr->args.getCount() == 0)
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index d5d3e5a5d..5967da8b0 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -273,7 +273,8 @@ namespace Slang
auto genericDeclRef = genericDeclRefType->getDeclRef();
ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric);
- List<Expr*> args;
+ List<Val*> args;
+ List<Val*> witnessArgs;
for (Decl* member : genericDeclRef.getDecl()->members)
{
if (auto typeParam = as<GenericTypeParamDecl>(member))
@@ -290,7 +291,7 @@ namespace Slang
// TODO: this is one place where syntax should get cloned!
if (outProperType)
- args.add(typeParam->initType.exp);
+ args.add(ExtractGenericArgVal(typeParam->initType.exp));
}
else if (auto valParam = as<GenericValueParamDecl>(member))
{
@@ -305,14 +306,42 @@ namespace Slang
}
// TODO: this is one place where syntax should get cloned!
if (outProperType)
- args.add(valParam->initExpr);
+ args.add(ExtractGenericArgVal(valParam->initExpr));
+ }
+ else if (auto constraintParam = as<GenericTypeConstraintDecl>(member))
+ {
+ auto genericParam = as<DeclRefType>(constraintParam->sub.type)->getDeclRef();
+ if (!genericParam)
+ return false;
+ auto genericTypeParamDecl = as<GenericTypeParamDecl>(genericParam.getDecl());
+ if (!genericTypeParamDecl)
+ return false;
+ auto defaultType = CheckProperType(genericTypeParamDecl->initType);
+ auto witness = tryGetSubtypeWitness(defaultType, CheckProperType(constraintParam->sup));
+ if (!witness)
+ {
+ // diagnose
+ getSink()->diagnose(
+ genericTypeParamDecl->initType.exp,
+ Diagnostics::typeArgumentDoesNotConformToInterface,
+ defaultType,
+ constraintParam->sup);
+
+ SLANG_UNEXPECTED("default type argument does not conform to interface");
+ return false;
+ }
+ witnessArgs.add(witness);
}
else
{
// ignore non-parameter members
}
}
- result = InstantiateGenericType(genericDeclRef, args);
+ // Combine args and witnessArgs
+ args.addRange(witnessArgs);
+
+ result = DeclRefType::create(getASTBuilder(),
+ getASTBuilder()->getGenericAppDeclRef(genericDeclRef, args.getArrayView()));
}
// default case: we expect this to already be a proper type
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 5a28c8188..14f45e802 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -726,6 +726,8 @@ DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL
DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.")
DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.")
+DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
+
//
// 8xxxx - Issues specific to a particular library/technology/platform/etc.
//
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index fa6fd2b43..a74459954 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -2286,7 +2286,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
auto prec = getInfo(EmitOp::Postfix);
needClose = maybeEmitParens(outerPrec, prec);
emitOperand(inst->getOperand(0), leftSide(outerPrec, prec));
- m_writer->emit("->getBuffer()");
+ if (as<IRPtrTypeBase>(inst->getOperand(0)->getDataType()))
+ m_writer->emit("->");
+ else
+ m_writer->emit(".");
+ m_writer->emit("getBuffer()");
break;
}
case kIROp_MakeString:
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c77f2a6ce..98c9c9803 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -45,6 +45,7 @@
#include "slang-ir-legalize-vector-types.h"
#include "slang-ir-metadata.h"
#include "slang-ir-optix-entry-point-uniforms.h"
+#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-ir-restructure.h"
#include "slang-ir-restructure-scoping.h"
#include "slang-ir-sccp.h"
@@ -369,6 +370,9 @@ Result linkAndOptimizeIR(
// being passed to saturated_cooperation
fuseCallsToSaturatedCooperation(irModule);
+ // Generate any requested derivative wrappers
+ generateDerivativeWrappers(irModule, sink);
+
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
// be illegal on the target platform. For example,
@@ -444,9 +448,11 @@ Result linkAndOptimizeIR(
{
case CodeGenTarget::PyTorchCppBinding:
generatePyTorchCppBinding(irModule, sink);
+ handleAutoBindNames(irModule);
break;
case CodeGenTarget::CUDASource:
removeTorchKernels(irModule);
+ handleAutoBindNames(irModule);
break;
default:
break;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 681c69cd3..335b6572e 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -332,12 +332,10 @@ namespace Slang
as<IRFuncType>(origFunc->getFullType()));
diffFunc->setFullType(diffFuncType);
- if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ if (origFunc->findDecoration<IRNameHintDecoration>())
{
- auto originalName = nameHint->getName();
- StringBuilder newNameSb;
- newNameSb << "s_bwd_" << originalName;
- builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ auto newName = this->getTranscribedFuncName(&builder, origFunc);
+ builder.addNameHintDecoration(diffFunc, newName);
}
// Transfer checkpoint hint decorations
@@ -492,6 +490,8 @@ namespace Slang
builder.emitCallInst(builder.getVoidType(), propagateFunc, propagateArgs);
builder.emitReturn();
+
+ addTranscribedFuncDecoration(builder, origFunc, cast<IRFunc>(header.differential));
return header;
}
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 15d558c22..68cb4e0c9 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -94,6 +94,9 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
// Transcribe a function definition.
virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0;
+ // Get transcribed function name from original name.
+ virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) = 0;
+
// Splits and transpose the parameter block.
// After this operation, the parameter block will contain parameters for both the future
// primal func and the future propagate func.
@@ -160,6 +163,19 @@ struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase
{
return kIROp_BackwardDerivativePrimalDecoration;
}
+ // Name for the function this transcriber generates from `func`: "s_primal_ctx_" followed by
+ // func's name hint, or the fixed name "s_primal_ctx_anonymous" when func carries no name hint.
+ virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override
+ {
+     auto hint = func->findDecoration<IRNameHintDecoration>();
+     if (!hint)
+         return builder->getStringValue(String("s_primal_ctx_anonymous").getUnownedSlice());
+
+     StringBuilder transcribedName;
+     transcribedName << "s_primal_ctx_" << hint->getName();
+     return builder->getStringValue(transcribedName.getUnownedSlice());
+ }
};
struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
@@ -196,6 +212,19 @@ struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
{
return kIROp_BackwardDerivativePropagateDecoration;
}
+ // Name for the function this transcriber generates from `func`: "s_bwd_prop_" followed by
+ // func's name hint, or the fixed name "s_bwd_prop_anonymous" when func carries no name hint.
+ virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override
+ {
+     auto hint = func->findDecoration<IRNameHintDecoration>();
+     if (!hint)
+         return builder->getStringValue(String("s_bwd_prop_anonymous").getUnownedSlice());
+
+     StringBuilder transcribedName;
+     transcribedName << "s_bwd_prop_" << hint->getName();
+     return builder->getStringValue(transcribedName.getUnownedSlice());
+ }
};
// A backward derivative function combines both primal + propagate functions and accepts no
@@ -235,6 +264,19 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase
{
builder->addBackwardDerivativeDecoration(inst, diffFunc);
}
+ // Name for the function this transcriber generates from `func`: "s_bwd_" followed by
+ // func's name hint, or the fixed name "s_bwd_anonymous" when func carries no name hint.
+ virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override
+ {
+     auto hint = func->findDecoration<IRNameHintDecoration>();
+     if (!hint)
+         return builder->getStringValue(String("s_bwd_anonymous").getUnownedSlice());
+
+     StringBuilder transcribedName;
+     transcribedName << "s_bwd_" << hint->getName();
+     return builder->getStringValue(transcribedName.getUnownedSlice());
+ }
};
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4c5608341..5b90e2711 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1039,6 +1039,7 @@ void stripDerivativeDecorations(IRInst* inst)
}
}
+
void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
for (auto inst : parent->getChildren())
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 6171d9a75..8f01a574f 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -295,9 +295,6 @@ public:
}
}
- if (differentiableOutputs == 0)
- sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveOutput);
-
DifferentiableLevel requiredDiffLevel = DifferentiableLevel::Forward;
if (isBackwardDifferentiableFunc(funcInst))
requiredDiffLevel = DifferentiableLevel::Backward;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 552a8af6d..b44b7b5d9 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -719,6 +719,11 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(CudaKernelDecoration, CudaKernel, 0, 0)
INST(CudaHostDecoration, CudaHost, 0, 0)
INST(TorchEntryPointDecoration, TorchEntryPoint, 0, 0)
+ INST(AutoPyBindCudaDecoration, AutoPyBindCUDA, 0, 0)
+ INST(CudaKernelForwardDerivativeDecoration, CudaKernelFwdDiffRef, 0, 0)
+ INST(CudaKernelBackwardDerivativeDecoration, CudaKernelBwdDiffRef, 0, 0)
+ INST(AutoPyBindExportInfoDecoration, PyBindExportFuncInfo, 0, 0)
+ INST(PyExportDecoration, PyExportDecoration, 0, 0)
/// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point.
INST(EntryPointParamDecoration, entryPointParam, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index b88df524a..63555c08d 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -460,6 +460,22 @@ struct IREntryPointDecoration : IRDecoration
IR_SIMPLE_DECORATION(CudaHostDecoration)
IR_SIMPLE_DECORATION(CudaKernelDecoration)
+struct IRCudaKernelForwardDerivativeDecoration : IRDecoration // Decoration whose operand 0 references a forward-derivative function.
+{
+ enum { kOp = kIROp_CudaKernelForwardDerivativeDecoration };
+ IR_LEAF_ISA(CudaKernelForwardDerivativeDecoration)
+ // Operand 0 is the `func` passed to addCudaKernelForwardDerivativeDecoration.
+ IRInst* getForwardDerivativeFunc() { return getOperand(0); }
+};
+
+struct IRCudaKernelBackwardDerivativeDecoration : IRDecoration // Decoration whose operand 0 references a backward-derivative function.
+{
+ enum { kOp = kIROp_CudaKernelBackwardDerivativeDecoration };
+ IR_LEAF_ISA(CudaKernelBackwardDerivativeDecoration)
+ // Operand 0 is the `func` passed to addCudaKernelBackwardDerivativeDecoration.
+ IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
+};
+
struct IRGeometryInputPrimitiveTypeDecoration: IRDecoration
{
IR_PARENT_ISA(GeometryInputPrimitiveTypeDecoration)
@@ -566,6 +582,33 @@ struct IRTorchEntryPointDecoration : IRDecoration
UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); }
};
+struct IRAutoPyBindCudaDecoration : IRDecoration // Decoration carrying a function name as a string-literal operand.
+{
+ enum
+ {
+ kOp = kIROp_AutoPyBindCudaDecoration
+ };
+ IR_LEAF_ISA(AutoPyBindCudaDecoration)
+ // Operand 0 is the string created by addAutoPyBindCudaDecoration from its `functionName` argument.
+ IRStringLit* getFunctionNameOperand() { return cast<IRStringLit>(getOperand(0)); }
+ UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } // Non-owning view of that string.
+};
+
+IR_SIMPLE_DECORATION(AutoPyBindExportInfoDecoration)
+
+struct IRPyExportDecoration : IRDecoration // Decoration carrying an export name as a string-literal operand.
+{
+ enum
+ {
+ kOp = kIROp_PyExportDecoration
+ };
+ IR_LEAF_ISA(PyExportDecoration)
+ // Operand 0 is the string created by addPyExportDecoration from its `exportName` argument.
+ IRStringLit* getExportNameOperand() { return cast<IRStringLit>(getOperand(0)); }
+ UnownedStringSlice getExportName() { return getExportNameOperand()->getStringSlice(); } // Non-owning view of that string.
+};
+
+
struct IRKnownBuiltinDecoration : IRDecoration
{
enum
@@ -4368,6 +4411,16 @@ public:
addDecoration(value, kIROp_TorchEntryPointDecoration, getStringValue(functionName));
}
+ void addAutoPyBindCudaDecoration(IRInst* value, UnownedStringSlice const& functionName) // Attach an AutoPyBindCuda decoration to `value`.
+ {
+ addDecoration(value, kIROp_AutoPyBindCudaDecoration, getStringValue(functionName)); // functionName becomes the decoration's string operand.
+ }
+
+ void addPyExportDecoration(IRInst* value, UnownedStringSlice const& exportName) // Attach a PyExport decoration to `value`.
+ {
+ addDecoration(value, kIROp_PyExportDecoration, getStringValue(exportName)); // exportName becomes the decoration's string operand.
+ }
+
void addCudaDeviceExportDecoration(IRInst* value, UnownedStringSlice const& functionName)
{
addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName));
@@ -4383,6 +4436,21 @@ public:
addDecoration(value, kIROp_CudaKernelDecoration);
}
+ void addCudaKernelForwardDerivativeDecoration(IRInst* value, IRInst* func) // Attach a decoration to `value` that references `func`.
+ {
+ addDecoration(value, kIROp_CudaKernelForwardDerivativeDecoration, func); // `func` is read back via getForwardDerivativeFunc().
+ }
+
+ void addCudaKernelBackwardDerivativeDecoration(IRInst* value, IRInst* func) // Attach a decoration to `value` that references `func`.
+ {
+ addDecoration(value, kIROp_CudaKernelBackwardDerivativeDecoration, func); // `func` is read back via getBackwardDerivativeFunc().
+ }
+
+ void addAutoPyBindExportInfoDecoration(IRInst* value) // Attach the operand-less AutoPyBindExportInfo marker decoration to `value`.
+ {
+ addDecoration(value, kIROp_AutoPyBindExportInfoDecoration);
+ }
+
void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName)
{
IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) };
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index cb1cf3db3..4b519b065 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -3659,7 +3659,14 @@ struct IRTypeLegalizationPass
// * `i` is a child of `inst`.
//
if (legalVal.flavor == LegalVal::Flavor::simple)
+ {
+ // The resulting inst may be different from the one we added to the
+ // worklist, so ensure that the appropriate flags are set.
+ //
+ setHasBeenAddedOrProcessed(legalVal.irValue);
+
inst = legalVal.irValue;
+ }
for( auto use = inst->firstUse; use; use = use->nextUse )
{
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index d59d57474..c0adef436 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -2,11 +2,14 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-diagnostics.h"
+#include "slang-ir-autodiff.h"
namespace Slang
{
// Convert a type to a target tuple type.
-static IRType* translateToTupleType(IRBuilder& builder, IRType* type)
+static IRType* translateToTupleType(
+ IRBuilder& builder,
+ IRType* type)
{
if (as<IRVoidType>(type))
return type;
@@ -312,34 +315,490 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
inst->removeAndDeallocate();
}
+// Map a CUDA-side parameter type to the type used by the generated host wrapper.
+//
+// * Basic types are returned unchanged.
+// * A tensor-view type is mapped through `getTorchTensorType` on its element type.
+// * A struct type is mapped to a newly created struct whose fields are the host
+//   translations of the original fields, in order, each under a fresh struct key.
+//
+// Returns nullptr when `type`, or any field type nested inside it, has no host mapping.
+// When `sink` is non-null, `unableToAutoMapCUDATypeToHostType` is diagnosed for the
+// innermost type that could not be mapped.
+IRType* translateToHostType(IRBuilder* builder, IRType* type, DiagnosticSink* sink = nullptr)
+{
+    if (as<IRBasicType>(type))
+        return type;
+
+    switch (type->getOp())
+    {
+    case kIROp_TensorViewType:
+        return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
+
+    case kIROp_StructType:
+        {
+            // Translate every field *before* creating the host struct, so that a
+            // failure does not leave a half-populated struct behind.
+            List<IRType*> fieldTypes;
+            for (auto field : as<IRStructType>(type)->getFields())
+            {
+                // Forward the sink so that a failure inside a nested field is reported,
+                // and stop instead of recording a null field type.
+                auto hostFieldType = translateToHostType(builder, field->getFieldType(), sink);
+                if (!hostFieldType)
+                    return nullptr;
+                fieldTypes.add(hostFieldType);
+            }
+
+            auto hostStructType = builder->createStructType();
+            for (auto hostFieldType : fieldTypes)
+            {
+                builder->createStructField(hostStructType, builder->createStructKey(), hostFieldType);
+            }
+
+            return hostStructType;
+        }
+    default:
+        break;
+    }
+
+    if (sink)
+        sink->diagnose(type->sourceLoc, Diagnostics::unableToAutoMapCUDATypeToHostType, type);
+    return nullptr;
+}
+
+// Emit the instructions that convert `inst`, a value of `hostType`, into a value of
+// `cudaType`. `hostType` is expected to be the result of `translateToHostType(cudaType)`.
+//
+// * basic -> basic : `inst` is returned unchanged.
+// * tensor view    : emits a make-tensor-view of `cudaType` from `inst`.
+// * struct         : extracts each host field, converts it recursively, and emits a
+//                    make-struct of `cudaType` from the converted fields.
+//
+// Returns nullptr if `hostType` or `inst` is null (i.e. the host translation failed
+// upstream). Any other unsupported type is treated as an internal error.
+IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaType, IRInst* inst)
+{
+    // Nothing to convert if the host side could not be produced.
+    if (!hostType || !inst)
+        return nullptr;
+
+    if (as<IRBasicType>(hostType) && as<IRBasicType>(cudaType))
+        return inst;
+
+    switch (cudaType->getOp())
+    {
+    case kIROp_TensorViewType:
+        return builder->emitMakeTensorView(cudaType, inst);
+
+    case kIROp_StructType:
+        {
+            auto cudaStructType = cast<IRStructType>(cudaType);
+            auto hostStructType = cast<IRStructType>(hostType);
+
+            List<IRStructField*> cudaFields;
+            for (auto field : cudaStructType->getFields())
+                cudaFields.add(field);
+
+            List<IRStructField*> hostFields;
+            for (auto field : hostStructType->getFields())
+                hostFields.add(field);
+
+            // The two structs are walked in lockstep below, so they must have the
+            // same number of fields; otherwise `hostFields[ii]` would read out of range.
+            SLANG_RELEASE_ASSERT(hostFields.getCount() == cudaFields.getCount());
+
+            List<IRInst*> resultFields;
+            for (Index ii = 0; ii < cudaFields.getCount(); ii++)
+            {
+                auto cudaFieldType = cudaFields[ii]->getFieldType();
+                auto hostField = hostFields[ii];
+                auto hostFieldType = hostField->getFieldType();
+
+                auto castedField = castHostToCUDAType(
+                    builder,
+                    hostFieldType,
+                    cudaFieldType,
+                    builder->emitFieldExtract(hostFieldType, inst, hostField->getKey()));
+
+                SLANG_RELEASE_ASSERT(castedField);
+                resultFields.add(castedField);
+            }
+
+            return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer());
+        }
+
+    default:
+        break;
+    }
+
+    // If translateToHostType worked correctly, we shouldn't get here.
+    SLANG_UNREACHABLE("unhandled type");
+}
+
+// Generate a parameterless reflection function for a torch-bound kernel.
+//
+// The generated function returns a 4-element target tuple:
+//   0. a tuple with the name of every parameter of `hostFunc`
+//      (`param<N>` is used when a parameter carries no name hint),
+//   1. a tuple with the py-export name of every parameter type of `kernelFunc`
+//      (empty string when the type has no py-export decoration),
+//   2. the extern-cpp name of the forward-derivative kernel (empty string if none),
+//   3. the extern-cpp name of the backward-derivative kernel (empty string if none).
+//
+// The function is exported as `__funcinfo__<extern-cpp name of hostFunc>` and marked as a
+// public, keep-alive torch entry point. `hostFunc` must carry an extern-cpp decoration.
+//
+// NOTE(review): tuple 0 is built from `hostFunc`'s parameters while tuple 1 is built from
+// `kernelFunc`'s; if the host wrapper has extra launch parameters the two tuples differ in
+// length -- confirm that consumers of this reflection data expect that.
+void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc)
+{
+    // The export name is derived from the host function's extern-cpp name, so fail
+    // loudly (instead of dereferencing null) if it is missing.
+    auto hostExternCppHint = hostFunc->findDecoration<IRExternCppDecoration>();
+    SLANG_RELEASE_ASSERT(hostExternCppHint);
+
+    // Every string in the result is a native string, including the "empty" defaults,
+    // so that each tuple element matches the element type declared for it.
+    auto makeNativeString = [&](UnownedStringSlice text) -> IRInst*
+    {
+        return builder->emitGetNativeString(builder->getStringValue(text));
+    };
+
+    // Create a new function.
+    auto reflectionFunc = builder->createFunc();
+    builder->setInsertInto(reflectionFunc);
+    builder->emitBlock();
+
+    // Go through the host func & collect the parameter names.
+    List<IRInst*> paramNames;
+    UIndex paramCount = 0;
+    for (auto param : hostFunc->getFirstBlock()->getParams())
+    {
+        if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+        {
+            paramNames.add(makeNativeString(nameHint->getName()));
+        }
+        else
+        {
+            StringBuilder argNameBuilder;
+            argNameBuilder << "param";
+            argNameBuilder << paramCount;
+
+            paramNames.add(makeNativeString(argNameBuilder.getUnownedSlice()));
+        }
+        paramCount++;
+    }
+
+    // Go through the kernel func & collect the exported type names.
+    List<IRInst*> paramTypeNames;
+    for (auto param : kernelFunc->getParams())
+    {
+        // Check for py-export decoration.
+        if (auto pyExportHint = param->getDataType()->findDecoration<IRPyExportDecoration>())
+            paramTypeNames.add(makeNativeString(pyExportHint->getExportName()));
+        else
+            paramTypeNames.add(makeNativeString(UnownedStringSlice("")));
+    }
+
+    // Create a target-tuple-type for the names
+    auto paramNamesTupleType = builder->getTargetTupleType(
+        (UInt)paramNames.getCount(),
+        List<IRType*>().makeRepeated(builder->getNativeStringType(), paramNames.getCount()).getBuffer());
+    auto paramNamesTuple = builder->emitMakeTargetTuple(paramNamesTupleType, paramNames.getCount(), paramNames.getBuffer());
+
+    // Create a target-tuple-type for the type names
+    auto paramTypeNamesTupleType = builder->getTargetTupleType(
+        (UInt)paramTypeNames.getCount(),
+        List<IRType*>().makeRepeated(builder->getNativeStringType(), paramTypeNames.getCount()).getBuffer());
+    auto paramTypeNamesTuple = builder->emitMakeTargetTuple(paramTypeNamesTupleType, paramTypeNames.getCount(), paramTypeNames.getBuffer());
+
+    // Find the fwd-diff function name (blank string indicates no fwd-diff)
+    IRInst* fwdDiffName = nullptr;
+    if (auto fwdDiffHint = kernelFunc->findDecoration<IRCudaKernelForwardDerivativeDecoration>())
+    {
+        auto fwdDiffFunc = fwdDiffHint->getForwardDerivativeFunc();
+
+        if (auto fwdDiffFuncExternHint = fwdDiffFunc->findDecoration<IRExternCppDecoration>())
+            fwdDiffName = makeNativeString(fwdDiffFuncExternHint->getName());
+    }
+    if (!fwdDiffName)
+        fwdDiffName = makeNativeString(UnownedStringSlice(""));
+
+    // Find the bwd-diff function name (blank string indicates no bwd-diff)
+    IRInst* bwdDiffName = nullptr;
+    if (auto bwdDiffHint = kernelFunc->findDecoration<IRCudaKernelBackwardDerivativeDecoration>())
+    {
+        auto bwdDiffFunc = bwdDiffHint->getBackwardDerivativeFunc();
+
+        if (auto bwdDiffFuncExternHint = bwdDiffFunc->findDecoration<IRExternCppDecoration>())
+            bwdDiffName = makeNativeString(bwdDiffFuncExternHint->getName());
+    }
+    if (!bwdDiffName)
+        bwdDiffName = makeNativeString(UnownedStringSlice(""));
+
+    auto stringType = builder->getNativeStringType();
+    auto returnTupleType = builder->getTargetTupleType(
+        4,
+        List<IRType*>(paramNamesTupleType, paramTypeNamesTupleType, stringType, stringType).getBuffer());
+
+    // Pack everything into the returned tuple.
+    auto returnTupleArgs = List<IRInst*>( paramNamesTuple, paramTypeNamesTuple, fwdDiffName, bwdDiffName );
+    auto returnTuple = builder->emitMakeTargetTuple(
+        returnTupleType,
+        returnTupleArgs.getCount(),
+        returnTupleArgs.getBuffer());
+    builder->emitReturn(returnTuple);
+
+    // Set function type.
+    auto funcType = builder->getFuncType(List<IRType*>(), returnTupleType);
+    reflectionFunc->setFullType(funcType);
+
+    // Set function name.
+    StringBuilder reflFuncExportName;
+    reflFuncExportName << "__funcinfo__" << hostExternCppHint->getName();
+
+    builder->addExternCppDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
+    builder->addTorchEntryPointDecoration(reflectionFunc, reflFuncExportName.getUnownedSlice());
+    builder->addPublicDecoration(reflectionFunc);
+    builder->addKeepAliveDecoration(reflectionFunc);
+}
+
+// Emit, at the builder's current insertion point, a host-side parameter that corresponds
+// to the CUDA kernel parameter `param`, followed by the instructions that convert it back
+// to the kernel's own parameter type.
+//
+// * `sink` receives a diagnostic when the parameter type has no host mapping.
+// * `outType`, when non-null, receives the host parameter type (nullptr on failure).
+//
+// Returns the converted value to pass to the kernel, or nullptr if the type could not be
+// mapped (in which case no parameter is emitted).
+IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, DiagnosticSink* sink, IRType** outType = nullptr)
+{
+    auto type = translateToHostType(builder, param->getDataType(), sink);
+    if (outType)
+        *outType = type;
+
+    // Without a host type there is nothing sensible to emit; the failure has already
+    // been reported through `sink`.
+    if (!type)
+        return nullptr;
+
+    auto hostParam = builder->emitParam(type);
+
+    // Carry the kernel parameter's name hint over to the host parameter unchanged.
+    if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
+    {
+        builder->addNameHintDecoration(hostParam, nameHint->getName());
+    }
+
+    // Then cast the param to the appropriate type.
+    return castHostToCUDAType(builder, type, param->getDataType(), hostParam);
+}
+
+// Recursively tag `type`, and every struct type reachable through its fields, with a
+// py-export decoration named after the struct's name hint. Non-struct types are ignored.
+// A struct that has neither a py-export decoration nor a name hint is an internal error.
+void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
+{
+    // Only struct types are ever decorated; everything else needs no work.
+    auto structType = as<IRStructType>(type);
+    if (!structType)
+        return;
+
+    // Leave an existing py-export decoration untouched.
+    const bool alreadyExported = structType->findDecoration<IRPyExportDecoration>() != nullptr;
+    if (!alreadyExported)
+    {
+        // The export name is taken from the struct's name hint.
+        UnownedStringSlice exportName;
+        auto nameHintDecoration = structType->findDecoration<IRNameHintDecoration>();
+        if (!nameHintDecoration)
+        {
+            // If there's no name hint, we can't export this type.
+            SLANG_UNEXPECTED("struct marked for export has no name");
+        }
+        else
+        {
+            exportName = nameHintDecoration->getName();
+        }
+
+        IRBuilder builder(structType->getModule());
+        builder.addPyExportDecoration(structType, exportName);
+    }
+
+    // Visit the field types as well, whether or not this struct was newly decorated.
+    for (auto field : structType->getFields())
+        markTypeForPyExport(field->getFieldType(), sink);
+}
+
+// Generate a parameterless reflection function for a py-exported type.
+//
+// The function returns a 2-element target tuple: a tuple of field names followed by a
+// tuple of field type names. A field without a name hint, or whose type carries no
+// py-export decoration, contributes an empty string. Only struct types contribute
+// fields; any other decorated type yields two empty tuples.
+//
+// The function is exported as `__typeinfo__<export name>` and marked as a public,
+// keep-alive torch entry point. Types without a py-export decoration are skipped.
+void generateReflectionForType(IRType* type, DiagnosticSink* sink)
+{
+    SLANG_UNUSED(sink);
+
+    // TODO: Fix this to avoid emitting the same type reflection multiple times.
+    auto exportDecoration = type->findDecoration<IRPyExportDecoration>();
+    if (!exportDecoration)
+        return;
+
+    IRBuilder builder(type->getModule());
+
+    auto reflFunc = builder.createFunc();
+    builder.setInsertInto(reflFunc);
+    builder.emitBlock();
+
+    // Helper: emit a native-string value for `text` at the current insertion point.
+    auto emitName = [&](UnownedStringSlice text) -> IRInst*
+    {
+        return builder.emitGetNativeString(builder.getStringValue(text));
+    };
+
+    List<IRInst*> fieldNames;
+    List<IRInst*> fieldTypeNames;
+
+    if (auto structType = as<IRStructType>(type))
+    {
+        for (auto field : structType->getFields())
+        {
+            // Field name: taken from the struct key's name hint, if any.
+            auto keyNameHint = field->getKey()->findDecoration<IRNameHintDecoration>();
+            if (keyNameHint)
+                fieldNames.add(emitName(keyNameHint->getName()));
+            else
+                fieldNames.add(emitName(UnownedStringSlice("")));
+
+            // Field type name: taken from the field type's py-export decoration, if any.
+            auto fieldTypeExport = field->getFieldType()->findDecoration<IRPyExportDecoration>();
+            if (fieldTypeExport)
+                fieldTypeNames.add(emitName(fieldTypeExport->getExportName()));
+            else
+                fieldTypeNames.add(emitName(UnownedStringSlice("")));
+        }
+    }
+
+    // Tuple of field names.
+    auto namesTupleType = builder.getTargetTupleType(
+        (UInt)fieldNames.getCount(),
+        List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldNames.getCount()).getBuffer());
+    auto namesTuple = builder.emitMakeTargetTuple(
+        namesTupleType, (UInt)fieldNames.getCount(), fieldNames.getBuffer());
+
+    // Tuple of field type names.
+    auto typeNamesTupleType = builder.getTargetTupleType(
+        (UInt)fieldTypeNames.getCount(),
+        List<IRType*>().makeRepeated(builder.getNativeStringType(), fieldTypeNames.getCount()).getBuffer());
+    auto typeNamesTuple = builder.emitMakeTargetTuple(
+        typeNamesTupleType, (UInt)fieldTypeNames.getCount(), fieldTypeNames.getBuffer());
+
+    // Pair the two tuples and return them.
+    auto resultTupleType = builder.getTargetTupleType(
+        2, List<IRType*>(namesTupleType, typeNamesTupleType).getBuffer());
+    auto resultTuple = builder.emitMakeTargetTuple(
+        resultTupleType,
+        2,
+        List<IRInst*>(namesTuple, typeNamesTuple).getBuffer());
+    builder.emitReturn(resultTuple);
+
+    // The function takes no parameters and returns the pair.
+    reflFunc->setFullType(builder.getFuncType(List<IRType*>(), resultTupleType));
+
+    // Export under `__typeinfo__<export name>`.
+    StringBuilder exportedFuncName;
+    exportedFuncName << "__typeinfo__" << exportDecoration->getExportName();
+
+    builder.addTorchEntryPointDecoration(reflFunc, exportedFuncName.getUnownedSlice());
+    builder.addExternCppDecoration(reflFunc, exportedFuncName.getUnownedSlice());
+    builder.addPublicDecoration(reflFunc);
+    builder.addKeepAliveDecoration(reflFunc);
+}
+
+// Generate a host-side wrapper for a kernel marked with an auto-py-bind-cuda decoration.
+//
+// The wrapper:
+//   * takes two leading uint3 parameters, `__blockSize` and `__gridSize`, that give the
+//     launch configuration;
+//   * takes one parameter per parameter of `func`, typed via `translateToHostType`
+//     (tensor views become torch tensors, structs become host structs with translated
+//     fields) and converted back to the kernel's types in the wrapper body;
+//   * dispatches `func` with the given block and grid size and returns void.
+//
+// The wrapper is decorated as a public, keep-alive CUDA-host torch entry point named
+// after the auto-bind decoration, and inherits `func`'s extern-cpp name if it has one.
+// If `func` also carries the export-info decoration, a reflection function is generated.
+//
+// Returns the wrapper, or nullptr if `func` has no auto-bind decoration or one of its
+// parameter types cannot be mapped to a host type (diagnosed through `sink`).
+IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
+{
+    // Check that the function has an auto-bind decoration
+    auto pybindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>();
+    if (!pybindCudaHint)
+        return nullptr;
+
+    // Create new function.
+    IRBuilder builder(func->getModule());
+
+    auto hostFunc = builder.createFunc();
+    builder.setInsertInto(hostFunc);
+    builder.emitBlock();
+
+    List<IRType*> hostParamTypes;
+
+    // Add the two uint3 parameters
+    auto uint3Type = builder.getVectorType(builder.getUIntType(), 3);
+
+    auto blockSizeParam = builder.emitParam(uint3Type);
+    hostParamTypes.add(uint3Type);
+    builder.addNameHintDecoration(blockSizeParam, UnownedStringSlice("__blockSize"));
+
+    auto gridSizeParam = builder.emitParam(uint3Type);
+    hostParamTypes.add(uint3Type);
+    builder.addNameHintDecoration(gridSizeParam, UnownedStringSlice("__gridSize"));
+
+    List<IRInst*> mappedParams;
+    for (auto param : func->getFirstBlock()->getParams())
+    {
+        IRType* hostParamType = nullptr;
+        auto mappedParam = generateHostParamForCUDAParam(&builder, param, sink, &hostParamType);
+        if (!mappedParam || !hostParamType)
+        {
+            // The parameter could not be mapped. Do not dispatch with a null argument
+            // or build a function type with a null parameter type; discard the
+            // partially built wrapper instead.
+            hostFunc->removeAndDeallocate();
+            return nullptr;
+        }
+
+        mappedParams.add(mappedParam);
+        hostParamTypes.add(hostParamType);
+        markTypeForPyExport(param->getDataType(), sink); // Should we be marking the host type?
+    }
+
+    // Dispatch the original function.
+    builder.emitDispatchKernelInst(
+        builder.getVoidType(),
+        func,
+        blockSizeParam,
+        gridSizeParam,
+        mappedParams.getCount(),
+        mappedParams.getBuffer());
+
+    builder.emitReturn();
+
+    IRFuncType* hostFuncType = builder.getFuncType(hostParamTypes, builder.getVoidType());
+    hostFunc->setFullType(hostFuncType);
+
+    // Mark for further processing of torch-specific insts.
+    builder.addTorchEntryPointDecoration(hostFunc, pybindCudaHint->getFunctionName());
+    // Mark for host-side emit logic.
+    builder.addCudaHostDecoration(hostFunc);
+    // Keep alive. This method will be accessed externally.
+    builder.addPublicDecoration(hostFunc);
+    builder.addKeepAliveDecoration(hostFunc);
+
+    if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
+    {
+        // Transfer to the host function.
+        builder.addExternCppDecoration(hostFunc, externCppHint->getName());
+    }
+
+    if (func->findDecoration<IRAutoPyBindExportInfoDecoration>())
+        generateReflectionFunc(&builder, func, hostFunc);
+
+    return hostFunc;
+}
+
void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
{
List<IRFunc*> workList;
List<IRFunc*> cudaKernels;
+ List<IRFunc*> autoBindRequests;
+ List<IRType*> typesToExport;
for (auto globalInst : module->getGlobalInsts())
{
- auto func = as<IRFunc>(globalInst);
- if (!func)
- continue;
- if (func->findDecoration<IRTorchEntryPointDecoration>())
- {
- workList.add(func);
- }
- else if (func->findDecoration<IRCudaKernelDecoration>())
+ if (auto func = as<IRFunc>(globalInst))
{
- cudaKernels.add(func);
+ if (func->findDecoration<IRAutoPyBindCudaDecoration>())
+ {
+ autoBindRequests.add(func);
+ }
+ if (func->findDecoration<IRTorchEntryPointDecoration>())
+ {
+ workList.add(func);
+ }
+ else if (func->findDecoration<IRCudaKernelDecoration>())
+ {
+ cudaKernels.add(func);
+ }
+ else
+ {
+ // Remove all other export decorations if this is not a cuda host func.
+ if (auto decor = func->findDecoration<IRPublicDecoration>())
+ decor->removeAndDeallocate();
+ if (auto decor = func->findDecoration<IRHLSLExportDecoration>())
+ decor->removeAndDeallocate();
+ if (auto decor = func->findDecoration<IRKeepAliveDecoration>())
+ decor->removeAndDeallocate();
+ if (auto decor = func->findDecoration<IRDllExportDecoration>())
+ decor->removeAndDeallocate();
+ }
}
- else
+ }
+
+ // Generate CUDA wrappers for all functions that have the auto-bind decoration.
+ for (auto func : autoBindRequests)
+ {
+ if (auto hostFunc = generateCUDAWrapperForFunc(func, sink))
{
- // Remove all other export decorations if this is not a cuda host func.
- if (auto decor = func->findDecoration<IRPublicDecoration>())
- decor->removeAndDeallocate();
- if (auto decor = func->findDecoration<IRHLSLExportDecoration>())
- decor->removeAndDeallocate();
- if (auto decor = func->findDecoration<IRKeepAliveDecoration>())
- decor->removeAndDeallocate();
- if (auto decor = func->findDecoration<IRDllExportDecoration>())
- decor->removeAndDeallocate();
+ // Add generated wrapper to worklist for python bindings.
+ workList.add(hostFunc);
}
}
@@ -355,6 +814,20 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
block = nextBlock;
}
}
+
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto type = as<IRType>(globalInst))
+ {
+ if (type->findDecoration<IRPyExportDecoration>())
+ {
+ typesToExport.add(type);
+ }
+ }
+ }
+
+ for (auto type : typesToExport)
+ generateReflectionForType(type, sink);
}
// Remove all [TorchEntryPoint] functions when emitting CUDA source.
@@ -372,4 +845,164 @@ void removeTorchKernels(IRModule* module)
inst->removeAndDeallocate();
}
+// Rename the extern-cpp symbol of every global instruction that carries an auto-bind
+// decoration by prepending `__kernel__` to it, so that the original name is free to be
+// used for the host function. Instructions without an extern-cpp decoration are left alone.
+void handleAutoBindNames(IRModule* module)
+{
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        if (!globalInst->findDecoration<IRAutoPyBindCudaDecoration>())
+            continue;
+
+        auto externCppHint = globalInst->findDecoration<IRExternCppDecoration>();
+        if (!externCppHint)
+            continue;
+
+        IRBuilder builder(module);
+
+        // Build the new name before the old decoration is destroyed, then swap
+        // the decoration for one that carries the prefixed name.
+        StringBuilder prefixedName;
+        prefixedName << "__kernel__" << externCppHint->getName();
+        externCppHint->removeAndDeallocate();
+        builder.addExternCppDecoration(globalInst, prefixedName.getUnownedSlice());
+    }
+}
+
+// Build one derivative wrapper (forward if `isForward`, backward otherwise) around `func`.
+//
+// The wrapper has the same signature as `func`, forwards its parameters to the
+// forward/backward-differentiate of `func`, and returns the call result. It is marked as
+// a CUDA kernel if `func` is one, gets an auto-py-bind decoration named
+// `<bind name>_fwd_diff` / `<bind name>_bwd_diff` so that a host binding is generated for
+// it, inherits the extern-cpp name with the same suffix, and is registered on `func`
+// through the matching cuda-kernel-derivative decoration.
+//
+// `func` must carry an auto-py-bind-cuda decoration.
+static void generateDerivativeWrapperForFunc(IRModule* module, IRFunc* func, bool isForward)
+{
+    const char* suffix = isForward ? "_fwd_diff" : "_bwd_diff";
+
+    // Create a new wrapper function.
+    IRBuilder builder(module);
+    auto wrapperFunc = builder.createFunc();
+    builder.setInsertInto(wrapperFunc);
+    builder.emitBlock();
+
+    // Clone the parameter list.
+    List<IRInst*> params;
+    for (auto param : func->getFirstBlock()->getParams())
+    {
+        params.add(builder.emitParam(param->getFullType()));
+    }
+
+    wrapperFunc->setFullType(func->getFullType());
+
+    // The derivative is given the original function's type on purpose (see the comment
+    // in generateDerivativeWrappers).
+    IRInst* diffFunc = nullptr;
+    if (isForward)
+        diffFunc = builder.emitForwardDifferentiateInst(func->getFullType(), func);
+    else
+        diffFunc = builder.emitBackwardDifferentiateInst(func->getFullType(), func);
+
+    auto diffCall = builder.emitCallInst(
+        func->getResultType(), diffFunc, params.getCount(), params.getBuffer());
+
+    builder.emitReturn(diffCall);
+
+    // If the original func is a CUDA kernel, mark the wrapper as a CUDA kernel as well.
+    if (func->findDecoration<IRCudaKernelDecoration>())
+        builder.addCudaKernelDecoration(wrapperFunc);
+
+    // Add an auto-pybind-cuda decoration to the wrapper function to further generate the
+    // host-side binding for the derivative kernel.
+    //
+    {
+        auto autoPyBindCudaHint = func->findDecoration<IRAutoPyBindCudaDecoration>();
+        StringBuilder nameBuilder;
+        nameBuilder << autoPyBindCudaHint->getFunctionName() << suffix;
+        builder.addAutoPyBindCudaDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
+    }
+
+    // Build a name for the wrapper function: <original_name><suffix>
+    if (auto externCppHint = func->findDecoration<IRExternCppDecoration>())
+    {
+        StringBuilder nameBuilder;
+        nameBuilder << externCppHint->getName() << suffix;
+        builder.addExternCppDecoration(wrapperFunc, nameBuilder.getUnownedSlice());
+    }
+
+    builder.addPublicDecoration(wrapperFunc);
+    builder.addKeepAliveDecoration(wrapperFunc);
+
+    if (isForward)
+        builder.addCudaKernelForwardDerivativeDecoration(func, wrapperFunc);
+    else
+        builder.addCudaKernelBackwardDerivativeDecoration(func, wrapperFunc);
+}
+
+// For every auto-bound function in `module` that is differentiable, generate wrapper
+// kernels that call its derivative(s):
+//   * a forward-derivative wrapper if the function is forward- or backward-differentiable;
+//   * additionally a backward-derivative wrapper if it is backward-differentiable.
+//
+// We won't actually employ the usual differentiable typing rules here: the wrappers copy
+// the original function's signature. This is because, for the auto-binding scenario, we
+// expect to only see tensor types, and their differentiation is handled using tensor
+// _pair_ types which handle the differentiable loads/stores through custom derivatives.
+// For now, the user is expected to explicitly use the tensor pair types. In the future,
+// when the type system can specify the corresponding pair type, this logic can be updated.
+void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink)
+{
+    SLANG_UNUSED(sink);
+
+    // Snapshot the candidates first: generating a wrapper creates new functions, and we
+    // do not want to depend on how that interacts with walking the global inst list.
+    List<IRFunc*> autoBoundFuncs;
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        auto func = as<IRFunc>(globalInst);
+        if (!func)
+            continue;
+
+        // Look for methods marked with auto-bind.
+        if (func->findDecoration<IRAutoPyBindCudaDecoration>())
+            autoBoundFuncs.add(func);
+    }
+
+    for (auto func : autoBoundFuncs)
+    {
+        const bool isFwdDifferentiable =
+            func->findDecoration<IRForwardDifferentiableDecoration>() != nullptr;
+        const bool isBwdDifferentiable =
+            func->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
+
+        if (isFwdDifferentiable || isBwdDifferentiable)
+            generateDerivativeWrapperForFunc(module, func, true);
+
+        // The reasoning for the reverse-mode is the same as the forward-mode version.
+        if (isBwdDifferentiable)
+            generateDerivativeWrapperForFunc(module, func, false);
+    }
+}
+
}
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h
index c35b6a8eb..dd7dcc9a4 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.h
+++ b/source/slang/slang-ir-pytorch-cpp-binding.h
@@ -7,6 +7,8 @@ class DiagnosticSink;
void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink);
void removeTorchKernels(IRModule* module);
+void handleAutoBindNames(IRModule* module);
+void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink);
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 8d1a86d4d..4f64087d8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1353,6 +1353,17 @@ static void addLinkageDecoration(
builder->addPublicDecoration(inst);
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
}
+ else if (as<AutoPyBindCudaAttribute>(modifier))
+ {
+ builder->addAutoPyBindCudaDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addAutoPyBindExportInfoDecoration(inst);
+ }
+ else if (auto pyExportModifier = as<PyExportAttribute>(modifier))
+ {
+ builder->addPyExportDecoration(inst, pyExportModifier->name.getLength()
+ ? pyExportModifier->name.getUnownedSlice()
+ : decl->getName()->text.getUnownedSlice());
+ }
else if (as<KnownBuiltinAttribute>(modifier))
{
// We add this to the internal instruction, like other name-like