summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang.h1
-rw-r--r--source/core/slang-type-text-util.cpp1
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h5
-rwxr-xr-xsource/slang/slang-compiler.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-lower-to-ir.cpp5
-rw-r--r--source/slang/slang-options.cpp4
-rw-r--r--tests/autodiff/cuda-kernel-export.slang29
10 files changed, 54 insertions, 2 deletions
diff --git a/slang.h b/slang.h
index d79d7829c..1a72c4348 100644
--- a/slang.h
+++ b/slang.h
@@ -571,6 +571,7 @@ extern "C"
SLANG_SHADER_HOST_CALLABLE, ///< A CPU target that makes the compiled shader code available to be run immediately
SLANG_CUDA_SOURCE, ///< Cuda source
SLANG_PTX, ///< PTX
+ SLANG_CUDA_OBJECT_CODE, ///< Object code that contains CUDA functions.
SLANG_OBJECT_CODE, ///< Object code that can be used for later linking
SLANG_HOST_CPP_SOURCE, ///< C++ code for host library or executable.
SLANG_HOST_HOST_CALLABLE, ///<
diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp
index 2ba3011d8..d37051e47 100644
--- a/source/core/slang-type-text-util.cpp
+++ b/source/core/slang-type-text-util.cpp
@@ -83,6 +83,7 @@ static const CompileTargetInfo s_compileTargetInfos[] =
{ SLANG_SHADER_SHARED_LIBRARY, "dll,so", "sharedlib,sharedlibrary,dll" },
{ SLANG_CUDA_SOURCE, "cu", "cuda,cu" },
{ SLANG_PTX, "ptx", "ptx" },
+ { SLANG_CUDA_OBJECT_CODE, "obj,o", "cuobj,cubin" },
{ SLANG_SHADER_HOST_CALLABLE, "", "host-callable,callable" },
{ SLANG_OBJECT_CODE, "obj,o", "object-code" },
{ SLANG_HOST_HOST_CALLABLE, "", "host-host-callable" },
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 2507c22dd..ad3817d9a 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2988,6 +2988,9 @@ attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute;
__attributeTarget(FuncDecl)
attribute_syntax [DllExport] : DllExportAttribute;
+__attributeTarget(FuncDecl)
+attribute_syntax [CudaDeviceExport] : CudaDeviceExportAttribute;
+
__attributeTarget(InterfaceDecl)
attribute_syntax [COM(guid: String)] : ComInterfaceAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 80c770c3a..c58b7de21 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1063,6 +1063,11 @@ class DllExportAttribute : public Attribute
SLANG_AST_CLASS(DllExportAttribute)
};
+class CudaDeviceExportAttribute : public Attribute
+{
+ SLANG_AST_CLASS(CudaDeviceExportAttribute)
+};
+
class DerivativeMemberAttribute : public Attribute
{
SLANG_AST_CLASS(DerivativeMemberAttribute)
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index e6b173e2a..784026761 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -81,6 +81,7 @@ namespace Slang
ShaderHostCallable = SLANG_SHADER_HOST_CALLABLE,
CUDASource = SLANG_CUDA_SOURCE,
PTX = SLANG_PTX,
+ CUDAObjectCode = SLANG_CUDA_OBJECT_CODE,
ObjectCode = SLANG_OBJECT_CODE,
HostHostCallable = SLANG_HOST_HOST_CALLABLE,
CountOf = SLANG_TARGET_COUNT_OF,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 71d9315bd..4516a6bc3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -712,6 +712,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(DllImportDecoration, dllImport, 2, 0)
/// An dllExport decoration marks a function as an export symbol. Slang will generate a native wrapper function that is exported to DLL.
INST(DllExportDecoration, dllExport, 1, 0)
+ /// An cudaDeviceExport decoration marks a function to be exported as a cuda __device__ function.
+ INST(CudaDeviceExportDecoration, cudaDeviceExport, 1, 0)
/// Marks an interface as a COM interface declaration.
INST(ComInterfaceDecoration, COMInterface, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9c4c1f4e2..cf58c22d0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3768,6 +3768,11 @@ public:
addDecoration(value, kIROp_DllExportDecoration, getStringValue(functionName));
}
+ void addCudaDeviceExportDecoration(IRInst* value, UnownedStringSlice const& functionName)
+ {
+ addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName));
+ }
+
void addEntryPointDecoration(IRInst* value, Profile profile, UnownedStringSlice const& name, UnownedStringSlice const& moduleName)
{
IRInst* operands[] = { getIntValue(getIntType(), profile.raw), getStringValue(name), getStringValue(moduleName) };
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5e6213205..e6b6b5c61 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1155,6 +1155,11 @@ static void addLinkageDecoration(
builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice());
builder->addPublicDecoration(inst);
}
+ if (decl->findModifier<CudaDeviceExportAttribute>())
+ {
+ builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice());
+ builder->addPublicDecoration(inst);
+ }
}
static void addLinkageDecoration(
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index fea9a1c71..944162acc 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -2185,8 +2185,8 @@ struct OptionsParser
// If we don't have any raw outputs but do have a raw target,
// and output type is callable, add an empty' rawOutput.
- if (rawOutputs.getCount() == 0 &&
- rawTargets.getCount() == 1 &&
+ if (rawOutputs.getCount() == 0 &&
+ rawTargets.getCount() == 1 &&
ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable)
{
RawOutput rawOutput;
diff --git a/tests/autodiff/cuda-kernel-export.slang b/tests/autodiff/cuda-kernel-export.slang
new file mode 100644
index 000000000..0db4d8cea
--- /dev/null
+++ b/tests/autodiff/cuda-kernel-export.slang
@@ -0,0 +1,29 @@
+//DISABLED_TEST:SIMPLE: -target cuda -line-directive-mode none
+
+// Verify that we can output a cuda device function with [CudaDeviceExport].
+// Disabled until we have FileCheck.
+
+struct MixedType : IDifferentiable
+{
+ no_diff float noDiffField;
+ float field;
+}
+
+[BackwardDifferentiable]
+float f1(MixedType m)
+{
+ return 2.0 * m.field;
+}
+
+[BackwardDifferentiable]
+float f(MixedType m)
+{
+ MixedType m1 = { m.noDiffField, m.field };
+ return f1(m1);
+}
+
+[CudaDeviceExport]
+void diffF(inout DifferentialPair<MixedType> m, float dout)
+{
+ __bwd_diff(f)(m, dout);
+}