diff options
Diffstat (limited to 'source')
30 files changed, 911 insertions, 93 deletions
diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index ca7dcb70f..3be0448c4 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -273,6 +273,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case SLANG_C_SOURCE: return Desc::make(Kind::Source, Payload::C, Style::Kernel, 0); case SLANG_CPP_SOURCE: return Desc::make(Kind::Source, Payload::Cpp, Style::Kernel, 0); case SLANG_HOST_CPP_SOURCE: return Desc::make(Kind::Source, Payload::Cpp, Style::Host, 0); + case SLANG_CPP_PYTORCH_BINDING: return Desc::make(Kind::Source, Payload::Cpp, Style::Host, 0); case SLANG_HOST_EXECUTABLE: return Desc::make(Kind::Executable, Payload::HostCPU, Style::Host, 0); case SLANG_SHADER_SHARED_LIBRARY: return Desc::make(Kind::SharedLibrary, Payload::HostCPU, Style::Kernel, 0); case SLANG_SHADER_HOST_CALLABLE: return Desc::make(Kind::HostCallable, Payload::HostCPU, Style::Kernel, 0); diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h index cc4d3e9fd..65d4d1bf9 100644 --- a/source/compiler-core/slang-artifact.h +++ b/source/compiler-core/slang-artifact.h @@ -208,7 +208,6 @@ enum class ArtifactStyle : uint8_t Kernel, ///< Compiled as `GPU kernel` style. Host, ///< Compiled in `host` style - Obfuscated, ///< Holds something specific to obfuscation, such as an obfuscated source map CountOf, diff --git a/source/core/slang-type-convert-util.cpp b/source/core/slang-type-convert-util.cpp index 6e6598357..c763b2835 100644 --- a/source/core/slang-type-convert-util.cpp +++ b/source/core/slang-type-convert-util.cpp @@ -17,6 +17,7 @@ namespace Slang case SLANG_HLSL: return SLANG_SOURCE_LANGUAGE_HLSL; case SLANG_C_SOURCE: return SLANG_SOURCE_LANGUAGE_C; case SLANG_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP; + case SLANG_CPP_PYTORCH_BINDING:return SLANG_SOURCE_LANGUAGE_CPP; case SLANG_HOST_CPP_SOURCE: return SLANG_SOURCE_LANGUAGE_CPP; case SLANG_CUDA_SOURCE: return SLANG_SOURCE_LANGUAGE_CUDA; default: break; diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index d37051e47..9d2f93ba3 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -78,6 +78,7 @@ static const CompileTargetInfo s_compileTargetInfos[] = { SLANG_SPIRV_ASM, "spv-asm", "spirv-asm,spirv-assembly" }, { SLANG_C_SOURCE, "c", "c" }, { SLANG_CPP_SOURCE, "cpp,c++,cxx", "cpp,c++,cxx" }, + { SLANG_CPP_PYTORCH_BINDING, "cpp,c++,cxx", "torch,torch-binding,torch-cpp,torch-cpp-binding" }, { SLANG_HOST_CPP_SOURCE, "cpp,c++,cxx", "host-cpp,host-c++,host-cxx"}, { SLANG_HOST_EXECUTABLE,"exe", "exe,executable" }, { SLANG_SHADER_SHARED_LIBRARY, "dll,so", "sharedlib,sharedlibrary,dll" }, diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 82a60a612..c45ad5bd6 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3081,6 +3081,9 @@ __attributeTarget(FuncDecl) attribute_syntax [DllExport] : DllExportAttribute; __attributeTarget(FuncDecl) +attribute_syntax [TorchEntryPoint] : TorchEntryPointAttribute; + +__attributeTarget(FuncDecl) attribute_syntax [CudaDeviceExport] : CudaDeviceExportAttribute; __attributeTarget(FuncDecl) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index bbe94dbc2..d5b70bbb3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,7 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; - __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; @@ -26,6 +25,89 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; +__generic<T> +__magic_type(TensorViewType) +__intrinsic_type($(kIROp_TensorViewType)) +struct TensorView +{ + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + Ptr<T> data_ptr(); + + __implicit_conversion($(kConversionCost_ImplicitDereference)) + __intrinsic_op($(kIROp_TorchTensorGetView)) + __init(TorchTensor<T> t); + + __target_intrinsic(cuda, "$0.load<$G0>($1)") + T load(uint x); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2)") + T load(uint x, uint y); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)") + T load(uint x, uint y, uint z); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)") + T load(uint x, uint y, uint z, uint w); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)") + T load(uint i0, uint i1, uint i2, uint i3, uint i4); + + __target_intrinsic(cuda, "$0.store<$G0>($1, $2)") + void store(uint x, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)") + void store(uint x, uint y, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4)") + void store(uint x, uint y, uint z, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5)") + void store(uint x, uint y, uint z, uint w, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)") + void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val); + + __target_intrinsic(cuda, "$0.dimensionCount") + uint dims(); + + __target_intrinsic(cuda, "$0.sizes[$1]") + uint size(uint i); + + __target_intrinsic(cuda, "$0.strides[$1]") + uint stride(uint i); +} + +__generic<T> +__intrinsic_type($(kIROp_TorchTensorType)) +struct TorchTensor +{ + __intrinsic_op($(kIROp_TorchTensorGetView)) + TensorView<T> getView(); + + __target_intrinsic(cuda, "$0.dims()") + __target_intrinsic(cpp, "$0.dims()") + uint dims(); + + __target_intrinsic(cuda, "$0.size($1)") + __target_intrinsic(cpp, "$0.size($1)") + uint size(uint i); + + __target_intrinsic(cuda, "$0.stride($1)") + __target_intrinsic(cpp, "$0.stride($1)") + uint stride(uint i); + + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + __target_intrinsic(cpp, "$0.data_ptr<$G0>()") + Ptr<T> data_ptr(); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y, uint z); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y, uint z, uint w); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); +} + __generic<T: IDifferentiable> __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) DifferentialPair<T> diffPair(T primal, T.Differential diff); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d417e3b7c..4c4aaa4f0 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6834,3 +6834,13 @@ void debugBreak(); __specialized_for_target(glsl) [[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]] void debugBreak(); + + +__target_intrinsic(cuda, "(threadIdx)") +uint3 cudaThreadIdx(); + +__target_intrinsic(cuda, "(blockIdx)") +uint3 cudaBlockIdx(); + +__target_intrinsic(cuda, "(blockDim)") +uint3 cudaBlockDim(); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 0d2e27e5f..00a6570ef 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1055,6 +1055,11 @@ class DllExportAttribute : public Attribute SLANG_AST_CLASS(DllExportAttribute) }; +class TorchEntryPointAttribute : public Attribute +{ + SLANG_AST_CLASS(TorchEntryPointAttribute) +}; + class CudaDeviceExportAttribute : public Attribute { SLANG_AST_CLASS(CudaDeviceExportAttribute) diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index fdbd56377..1fed2d52a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -275,6 +275,13 @@ BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TensorViewType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Type* TensorViewType::getElementType() +{ + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); +} + + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void VectorExpressionType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 47608405a..cb3fde9f9 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -524,6 +524,15 @@ private: MatrixExpressionType(Type*, IntVal*, IntVal*) {} }; +class TensorViewType : public BuiltinType +{ + SLANG_AST_CLASS(TensorViewType) + + Type* getElementType(); +private: + TensorViewType(Type*) {} +}; + // Base class for built in string types class StringTypeBase : public BuiltinType { diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 8cbe12ef0..ff38bbbde 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -470,6 +470,7 @@ namespace Slang case CodeGenTarget::CUDASource: case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: { return PassThroughMode::None; @@ -1570,6 +1571,7 @@ namespace Slang case CodeGenTarget::CUDASource: case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: { RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b287b21eb..b49ce4fc3 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -75,6 +75,7 @@ namespace Slang DXILAssembly = SLANG_DXIL_ASM, CSource = SLANG_C_SOURCE, CPPSource = SLANG_CPP_SOURCE, + PyTorchCppBinding = SLANG_CPP_PYTORCH_BINDING, HostCPPSource = SLANG_HOST_CPP_SOURCE, HostExecutable = SLANG_HOST_EXECUTABLE, ShaderSharedLibrary = SLANG_SHADER_SHARED_LIBRARY, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c3e0adbca..39ceb6678 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -664,6 +664,8 @@ DIAGNOSTIC(54002, Error, meshOutputMustBeArray, "HLSL style mesh shader outputs DIAGNOSTIC(54003, Error, meshOutputArrayMustHaveSize, "HLSL style mesh shader output arrays must have a length specified") DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL style mesh shader output modifier") +DIAGNOSTIC(55101, Error, invalidTorchKernelReturnType, "'$0' is not a valid return type for a pytorch kernel function.") +DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid parameter type for a pytorch kernel function.") // // 8xxxx - Issues specific to a particular library/technology/platform/etc. diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 356f1c7ce..ebc312560 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -129,6 +129,7 @@ Index LocationTracker::getValue(Kind kind, IRInst* inst, IRDecoration* decoratio } case CodeGenTarget::CPPSource: case CodeGenTarget::HostCPPSource: + case CodeGenTarget::PyTorchCppBinding: { return SourceLanguage::CPP; } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 2a2ae06c6..346926712 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -304,6 +304,18 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S out << ">"; return SLANG_OK; } + case kIROp_TargetTupleType: + { + out << "std::tuple<"; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + if (i > 0) out << ", "; + auto elementType = (IRType*)type->getOperand(i); + SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out)); + } + out << ">"; + return SLANG_OK; + } default: { if (isNominalOp(type->getOp())) @@ -1187,6 +1199,19 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut return true; } + case kIROp_MakeTargetTuple: + { + m_writer->emit("std::make_tuple("); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: case kIROp_FloatCast: diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 846b3b1f2..d2fa892ba 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -146,6 +146,11 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, out << prefix << vecCount; return SLANG_OK; } + case kIROp_TensorViewType: + { + out << "TensorView"; + return SLANG_OK; + } default: { if (isNominalOp(type->getOp())) diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp new file mode 100644 index 000000000..ef67c520a --- /dev/null +++ b/source/slang/slang-emit-torch.cpp @@ -0,0 +1,181 @@ +// slang-emit-torch.cpp +#include "slang-emit-torch.h" + +#include "../core/slang-writer.h" + +#include "slang-emit-source-writer.h" +#include "slang-mangled-lexer.h" + +#include <assert.h> + +namespace Slang +{ +bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + switch (inst->getOp()) + { + default: + { + return Super::tryEmitInstExprImpl(inst, inOuterPrec); + } + case kIROp_MakeTensorView: + { + m_writer->emit("make_tensor_view("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_CudaKernelLaunch: + { + m_writer->emit("cudaLaunchKernel("); + // func + m_writer->emit("(const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // gridDim + m_writer->emit("slang_bit_cast<dim3>("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // blockDim + m_writer->emit("slang_bit_cast<dim3>("); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit("), "); + + // args + emitOperand(inst->getOperand(3), getInfo(EmitOp::General)); + m_writer->emit(", "); + + // shared mem + m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")), "); + + // stream + m_writer->emit("((cudaStream_t)"); + emitOperand(inst->getOperand(4), getInfo(EmitOp::General)); + m_writer->emit("))"); + return true; + } + case kIROp_TorchGetCudaStream: + { + m_writer->emit("at::cuda::getCurrentCUDAStream()"); + return true; + } + case kIROp_AllocateTorchTensor: + { + /* + Emit something like: + ``` + torch::Tensor out = torch::empty({ dimX, dimY, dimZ, ... }, + torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); + ``` + */ + m_writer->emit("torch::empty({ "); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (i > 0) + m_writer->emit(", "); + auto arg = inst->getOperand(i); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit("}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::"); + switch (inst->getDataType()->getOperand(0)->getOp()) + { + case kIROp_FloatType: + m_writer->emit("kFloat32"); + break; + case kIROp_HalfType: + m_writer->emit("kFloat16"); + break; + case kIROp_DoubleType: + m_writer->emit("kFloat64"); + break; + case kIROp_UInt8Type: + m_writer->emit("kUInt8"); + break; + case kIROp_UInt16Type: + m_writer->emit("kUInt16"); + break; + case kIROp_UIntType: + m_writer->emit("kUInt32"); + break; + case kIROp_UInt64Type: + m_writer->emit("kUInt64"); + break; + case kIROp_Int8Type: + m_writer->emit("kInt8"); + break; + case kIROp_Int16Type: + m_writer->emit("kInt16"); + break; + case kIROp_IntType: + m_writer->emit("kInt32"); + break; + case kIROp_Int64Type: + m_writer->emit("kInt64"); + break; + default: + SLANG_UNEXPECTED("unknown scalar type in allocTorchTensor"); + break; + } + m_writer->emit("))"); + return true; + } + } +} + +SlangResult TorchCppSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) +{ + switch (type->getOp()) + { + default: + return Super::calcTypeName(type, target, out); + case kIROp_TensorViewType: + { + out << "TensorView"; + return SLANG_OK; + } + case kIROp_TorchTensorType: + { + out << "torch::Tensor"; + return SLANG_OK; + } + case kIROp_TorchKernelMemoryAllocatorType: + { + out << "CudaTaskMemoryAllocator"; + return SLANG_OK; + } + } +} + +void TorchCppSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) +{ + Super::emitModuleImpl(module, sink); + + // Emit PyBind declarations. + m_writer->emit("PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n"); + m_writer->indent(); + for (auto inst : module->getGlobalInsts()) + { + auto func = as<IRFunc>(inst); + if (!func) continue; + auto decor = func->findDecoration<IRTorchEntryPointDecoration>(); + if (!decor) continue; + m_writer->emit("m.def("); + emitStringLiteral(decor->getFunctionName()); + m_writer->emit(", &"); + m_writer->emit(decor->getFunctionName()); + m_writer->emit(", "); + emitStringLiteral(decor->getFunctionName()); + m_writer->emit(");\n"); + } + m_writer->dedent(); + m_writer->emit("}\n"); + +} + +} // namespace Slang diff --git a/source/slang/slang-emit-torch.h b/source/slang/slang-emit-torch.h new file mode 100644 index 000000000..84ce42331 --- /dev/null +++ b/source/slang/slang-emit-torch.h @@ -0,0 +1,28 @@ +// slang-emit-torch.h +#ifndef SLANG_EMIT_TORCH_H +#define SLANG_EMIT_TORCH_H + +#include "slang-emit-cpp.h" + +namespace Slang +{ + +class TorchCppSourceEmitter : public CPPSourceEmitter +{ +public: + typedef CPPSourceEmitter Super; + + TorchCppSourceEmitter(const Desc& desc) : + Super(desc) + { + } + +protected: + // CPPSourceEmitter overrides + virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) override; + virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) override; + virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) override; +}; + +} +#endif diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1b4eed8fd..fe72efcc7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -56,6 +56,7 @@ #include "slang-ir-glsl-liveness.h" #include "slang-ir-string-hash.h" #include "slang-ir-simplify-for-emit.h" +#include "slang-ir-pytorch-cpp-binding.h" #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -74,6 +75,7 @@ #include "slang-emit-hlsl.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" +#include "slang-emit-torch.h" #include "../compiler-core/slang-artifact-desc-util.h" #include "../compiler-core/slang-artifact-util.h" @@ -83,6 +85,7 @@ #include <assert.h> Slang::String get_slang_cpp_host_prelude(); +Slang::String get_slang_torch_prelude(); namespace Slang { @@ -402,6 +405,18 @@ Result linkAndOptimizeIR( finalizeSpecialization(irModule); + switch (target) + { + case CodeGenTarget::PyTorchCppBinding: + generatePyTorchCppBinding(irModule, sink); + break; + case CodeGenTarget::CUDASource: + removeTorchKernels(irModule); + break; + default: + break; + } + // If we have a target that is GPU like we use the string hashing mechanism // but for that to work we need to inline such that calls (or returns) of strings // boil down into getStringHash(stringLiteral) @@ -969,31 +984,39 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr LinkedIR linkedIR; RefPtr<CLikeSourceEmitter> sourceEmitter; - SourceLanguage sourceLanguage = CLikeSourceEmitter::getSourceLanguage(target); - switch (sourceLanguage) + + switch (target) { - case SourceLanguage::CPP: - { - sourceEmitter = new CPPSourceEmitter(desc); - break; - } - case SourceLanguage::GLSL: - { - sourceEmitter = new GLSLSourceEmitter(desc); - break; - } - case SourceLanguage::HLSL: - { - sourceEmitter = new HLSLSourceEmitter(desc); - break; - } - case SourceLanguage::CUDA: + default: + switch (sourceLanguage) { - sourceEmitter = new CUDASourceEmitter(desc); - break; + case SourceLanguage::CPP: + { + sourceEmitter = new CPPSourceEmitter(desc); + break; + } + case SourceLanguage::GLSL: + { + sourceEmitter = new GLSLSourceEmitter(desc); + break; + } + case SourceLanguage::HLSL: + { + sourceEmitter = new HLSLSourceEmitter(desc); + break; + } + case SourceLanguage::CUDA: + { + sourceEmitter = new CUDASourceEmitter(desc); + break; + } + default: break; } - default: break; + break; + case CodeGenTarget::PyTorchCppBinding: + sourceEmitter = new TorchCppSourceEmitter(desc); + break; } if (!sourceEmitter) @@ -1072,16 +1095,23 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr // Emit any front matter sourceEmitter->emitFrontMatter(targetRequest); - // If heterogeneous we output the prelude before everything else - if (isHeterogeneousTarget(target)) - { - sourceWriter.emit(get_slang_cpp_host_prelude()); - } - else + switch (target) { - // Get the prelude - String prelude = session->getPreludeForLanguage(sourceLanguage); - sourceWriter.emit(prelude); + case CodeGenTarget::PyTorchCppBinding: + sourceWriter.emit(get_slang_torch_prelude()); + break; + default: + if (isHeterogeneousTarget(target)) + { + sourceWriter.emit(get_slang_cpp_host_prelude()); + } + else + { + // Get the prelude + String prelude = session->getPreludeForLanguage(sourceLanguage); + sourceWriter.emit(prelude); + } + break; } // Emit anything that goes before the contents of the code generated for the module diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 04e08293f..68f1a28e6 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -67,6 +67,11 @@ INST(Nop, nop, 0, 0) INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) + INST(TensorViewType, TensorView, 1, HOISTABLE) + INST(TorchTensorType, TorchTensor, 0, HOISTABLE) + INST(TorchKernelMemoryAllocatorType, TorchMemAllocatorType, 0, HOISTABLE) + INST(ArrayListType, ArrayListVector, 1, HOISTABLE) + /* BindExistentialsTypeBase */ // A `BindExistentials<B, T0,w0, T1,w1, ...>` represents @@ -220,6 +225,7 @@ INST(ThisType, this_type, 0, HOISTABLE) INST(RTTIType, rtti_type, 0, HOISTABLE) INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) INST(TupleType, tuple_type, 0, HOISTABLE) +INST(TargetTupleType, TargetTuple, 0, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE) @@ -308,6 +314,7 @@ INST(MakeArray, makeArray, 0, 0) INST(MakeArrayFromElement, makeArrayFromElement, 1, 0) INST(MakeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) +INST(MakeTargetTuple, makeTuple, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) INST(MakeResultValue, makeResultValue, 1, 0) INST(MakeResultError, makeResultError, 1, 0) @@ -509,24 +516,24 @@ INST(SwizzledStore, swizzledStore, 2, 0) /* IRConditionalbranch */ // conditionalBranch <condition> <trueBlock> <falseBlock> - INST(conditionalBranch, conditionalBranch, 3, 0) +INST(conditionalBranch, conditionalBranch, 3, 0) - // ifElse <condition> <trueBlock> <falseBlock> <mergeBlock> - INST(ifElse, ifElse, 4, 0) - INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) +// ifElse <condition> <trueBlock> <falseBlock> <mergeBlock> +INST(ifElse, ifElse, 4, 0) +INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) - INST(Throw, throw, 1, 0) - // tryCall <successBlock> <failBlock> <callee> <args>... - INST(TryCall, tryCall, 3, 0) - // switch <val> <break> <default> <caseVal1> <caseBlock1> ... - INST(Switch, switch, 3, 0) +INST(Throw, throw, 1, 0) +// tryCall <successBlock> <failBlock> <callee> <args>... +INST(TryCall, tryCall, 3, 0) +// switch <val> <break> <default> <caseVal1> <caseBlock1> ... +INST(Switch, switch, 3, 0) - INST(discard, discard, 0, 0) +INST(discard, discard, 0, 0) - /* IRUnreachable */ - INST(MissingReturn, missingReturn, 0, 0) - INST(Unreachable, unreachable, 0, 0) - INST_RANGE(Unreachable, MissingReturn, Unreachable) +/* IRUnreachable */ +INST(MissingReturn, missingReturn, 0, 0) +INST(Unreachable, unreachable, 0, 0) +INST_RANGE(Unreachable, MissingReturn, Unreachable) INST_RANGE(TerminatorInst, Return, Unreachable) @@ -575,10 +582,10 @@ INST(GetStringHash, getStringHash, 1, 0) INST(WaveGetActiveMask, waveGetActiveMask, 0, 0) - /// trueMask = waveMaskBallot(mask, condition) +/// trueMask = waveMaskBallot(mask, condition) INST(WaveMaskBallot, waveMaskBallot, 2, 0) - /// matchMask = waveMaskBallot(mask, value) +/// matchMask = waveMaskBallot(mask, value) INST(WaveMaskMatch, waveMaskMatch, 2, 0) // Texture sampling operation of the form `t.Sample(s,u)` @@ -604,6 +611,12 @@ INST(GetOptiXHitAttribute, getOptiXHitAttribute, 2, 0) // using a pointer. INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0) +INST(MakeArrayList, makeArrayList, 0, 0) +INST(MakeTensorView, makeTensorView, 0, 0) +INST(AllocateTorchTensor, allocTorchTensor , 0, 0) +INST(TorchGetCudaStream, TorchGetCudaStream, 0, 0) +INST(TorchTensorGetView, TorchTensorGetView, 0, 0) + /* Decoration */ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) @@ -669,7 +682,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(CudaKernelDecoration, CudaKernel, 0, 0) INST(CudaHostDecoration, CudaHost, 0, 0) - + INST(TorchEntryPointDecoration, TorchEntryPoint, 0, 0) + /// Used to mark parameters that are moved from entry point parameters to global params as coming from the entry point. INST(EntryPointParamDecoration, entryPointParam, 0, 0) @@ -908,6 +922,7 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) INST(PrimalSubstitute, PrimalSubstitute, 1, 0) INST(DispatchKernel, DispatchKernel, 3, 0) +INST(CudaKernelLaunch, CudaKernelLaunch, 6, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9bb66823b..4cdf6c749 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -518,6 +518,18 @@ struct IRDllExportDecoration : IRDecoration UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } }; +struct IRTorchEntryPointDecoration : IRDecoration +{ + enum + { + kOp = kIROp_TorchEntryPointDecoration + }; + IR_LEAF_ISA(TorchEntryPointDecoration) + + IRStringLit* getFunctionNameOperand() { return cast<IRStringLit>(getOperand(0)); } + UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } +}; + struct IRFormatDecoration : IRDecoration { enum { kOp = kIROp_FormatDecoration }; @@ -936,6 +948,15 @@ struct IRDispatchKernel : IRInst IR_LEAF_ISA(DispatchKernel) }; +struct IRTorchTensorGetView : IRInst +{ + enum + { + kOp = kIROp_TorchTensorGetView + }; + IR_LEAF_ISA(TorchTensorGetView) +}; + // Dictionary item mapping a type with a corresponding // IDifferentiable witness table // @@ -2720,6 +2741,8 @@ public: IRAnyValueType* getAnyValueType(IRInst* size); IRDynamicType* getDynamicType(); + IRTargetTupleType* getTargetTupleType(UInt count, IRType* const* types); + IRTupleType* getTupleType(UInt count, IRType* const* types); IRTupleType* getTupleType(List<IRType*> const& types) { @@ -2775,6 +2798,10 @@ public: IRInst* rowCount, IRInst* columnCount); + IRArrayListType* getArrayListType(IRType* elementType); + IRTensorViewType* getTensorViewType(IRType* elementType); + IRTorchTensorType* getTorchTensorType(); + IRDifferentialPairType* getDifferentialPairType( IRType* valueType, IRInst* witnessTable); @@ -2896,7 +2923,10 @@ public: IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); + IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs); + IRInst* emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream); + IRInst* emitGetTorchCudaStream(); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); @@ -2999,6 +3029,8 @@ public: // Creates an RTTI object. Result is of `IRRTTIType`. IRInst* emitMakeRTTIObject(IRInst* typeInst); + IRInst* emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args); + IRInst* emitMakeTuple(IRType* type, UInt count, IRInst* const* args); IRInst* emitMakeTuple(UInt count, IRInst* const* args); @@ -3067,6 +3099,11 @@ public: UInt argCount, IRInst* const* args); + IRInst* emitMakeArrayList( + IRType* type, + UInt argCount, + IRInst* const* args); + IRInst* emitMakeArrayFromElement( IRType* type, IRInst* element); @@ -3083,6 +3120,8 @@ public: return emitMakeStruct(type, args.getCount(), args.getBuffer()); } + IRInst* emitMakeTensorView(IRType* type, IRInst* allocator, IRInst* val); + IRInst* emitMakeExistential( IRType* type, IRInst* value, @@ -3785,6 +3824,11 @@ public: addDecoration(value, kIROp_DllExportDecoration, getStringValue(functionName)); } + void addTorchEntryPointDecoration(IRInst* value, UnownedStringSlice const& functionName) + { + addDecoration(value, kIROp_TorchEntryPointDecoration, getStringValue(functionName)); + } + void addCudaDeviceExportDecoration(IRInst* value, UnownedStringSlice const& functionName) { addDecoration(value, kIROp_CudaDeviceExportDecoration, getStringValue(functionName)); diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp new file mode 100644 index 000000000..e33adec1d --- /dev/null +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -0,0 +1,248 @@ +#include "slang-ir-pytorch-cpp-binding.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-diagnostics.h" + +namespace Slang +{ +static bool getHostReturnTypeImpl(List<IRType*>& elementTypes, IRBuilder& builder, IRType* type) +{ + bool isValid = true; + if (as<IRVoidType>(type)) + return true; + if (as<IRBasicType>(type)) + elementTypes.add(type); + else if (as<IRTorchTensorType>(type)) + elementTypes.add(type); + else if (auto vectorType = as<IRVectorType>(type)) + { + auto count = as<IRIntLit>(vectorType->getElementCount()); + if (!count) + { + return false; + } + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + elementTypes.addRange(vectorType->getElementType()); + } + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) + { + return false; + } + List<IRType*> subElementTypes; + isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType()); + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + elementTypes.addRange(subElementTypes); + } + } + else if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType()); + } + } + else + { + return false; + } + return isValid; +} + +static IRType* getHostReturnType(IRBuilder& builder, IRType* type) +{ + List<IRType*> types; + bool isValid = getHostReturnTypeImpl(types, builder, type); + if (isValid) + return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer()); + return nullptr; +} + +static void flattenToTupleImpl(List<IRInst*>& result, IRBuilder& builder, IRInst* val) +{ + auto type = val->getDataType(); + if (as<IRVoidType>(type)) + return; + if (as<IRBasicType>(type)) + result.add(val); + else if (as<IRTorchTensorType>(type)) + result.add(val); + else if (auto vectorType = as<IRVectorType>(type)) + { + auto count = as<IRIntLit>(vectorType->getElementCount()); + if (!count) + { + return; + } + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + result.add(builder.emitElementExtract(vectorType->getElementType(), builder.getIntValue(builder.getIntType(), i))); + } + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) + { + return; + } + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + flattenToTupleImpl(result, builder, elementVal); + } + } + else if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey()); + flattenToTupleImpl(result, builder, elementVal); + } + } +} + +static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val) +{ + List<IRInst*> vals; + flattenToTupleImpl(vals, builder, val); + return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer()); +} + +static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) +{ + IRBuilder builder(func); + + builder.setInsertBefore(func); + auto hostReturnType = getHostReturnType(builder, func->getResultType()); + if (!hostReturnType) + { + sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType()); + return; + } + List<IRType*> hostParamTypes; + auto funcType = as<IRFuncType>(func->getDataType()); + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + hostParamTypes.add(funcType->getParamType(i)); + } + auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType); + func->setFullType(bindingFuncType); + + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); + + List<IRInst*> instsToRemove; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto kernelDispatch = as<IRDispatchKernel>(inst)) + { + builder.setInsertBefore(kernelDispatch); + List<IRInst*> kernelArgs; + auto kernelArgCount = kernelDispatch->getArgCount(); + auto argArrayType = builder.getArrayType(builder.getPtrType(builder.getVoidType()), + builder.getIntValue(builder.getIntType(), kernelArgCount)); + auto argArrayVar = builder.emitVar(argArrayType); + for (UInt i = 0; i < kernelArgCount; i++) + { + auto arg = kernelDispatch->getArg(i); + auto argVar = builder.emitVar(arg->getFullType()); + builder.emitStore(argVar, arg); + auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i)); + builder.emitStore(addr, argVar); + } + auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0)); + builder.emitCudaKernelLaunch( + kernelDispatch->getBaseFn(), + kernelDispatch->getDispatchSize(), + kernelDispatch->getThreadGroupSize(), + argArrayPtr, + builder.emitGetTorchCudaStream()); + instsToRemove.add(inst); + } + else if (auto getView = as<IRTorchTensorGetView>(inst)) + { + builder.setInsertBefore(getView); + auto makeView = builder.emitMakeTensorView(getView->getFullType(), allocator, inst->getOperand(0)); + getView->replaceUsesWith(makeView); + instsToRemove.add(getView); + } + else if (auto ret = as<IRReturn>(inst)) + { + builder.setInsertBefore(ret); + auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal()); + ret->setOperand(0, retVal); + } + } + } + + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); +} + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +{ + List<IRFunc*> workList; + List<IRFunc*> cudaKernels; + for (auto globalInst : module->getGlobalInsts()) + { + auto func = as<IRFunc>(globalInst); + if (!func) + continue; + if (func->findDecoration<IRTorchEntryPointDecoration>()) + { + workList.add(func); + } + else if (func->findDecoration<IRCudaKernelDecoration>()) + { + cudaKernels.add(func); + } + else + { + // Remove all other export decorations if this is not a cuda host func. + if (auto decor = func->findDecoration<IRPublicDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRHLSLExportDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRKeepAliveDecoration>()) + decor->removeAndDeallocate(); + if (auto decor = func->findDecoration<IRDllExportDecoration>()) + decor->removeAndDeallocate(); + } + } + + for (auto func : workList) + generateCppBindingForFunc(func, sink); + + for (auto func : cudaKernels) + { + for (auto block = func->getFirstBlock(); block;) + { + auto nextBlock = block->getNextBlock(); + block->removeAndDeallocate(); + block = nextBlock; + } + } +} + +// Remove all [TorchEntryPoint] functions when emitting CUDA source. +void removeTorchKernels(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (!as<IRFunc>(globalInst)) + continue; + if (globalInst->findDecoration<IRTorchEntryPointDecoration>()) + globalInst->removeAndDeallocate(); + } + +} + +} diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h new file mode 100644 index 000000000..c35b6a8eb --- /dev/null +++ b/source/slang/slang-ir-pytorch-cpp-binding.h @@ -0,0 +1,12 @@ +#pragma once + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink); +void removeTorchKernels(IRModule* module); + +} + diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 69870c128..6ce54a948 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2594,6 +2594,11 @@ namespace Slang IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); } + IRTargetTupleType* IRBuilder::getTargetTupleType(UInt count, IRType* const* types) + { + return (IRTargetTupleType*)getType(kIROp_TargetTupleType, count, (IRInst* const*)types); + } + IRAssociatedType* IRBuilder::getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes) { return (IRAssociatedType*)getType(kIROp_AssociatedType, @@ -2788,6 +2793,27 @@ namespace Slang operands); } + IRArrayListType* IRBuilder::getArrayListType(IRType* elementType) + { + return (IRArrayListType*)getType( + kIROp_ArrayListType, + 1, + (IRInst**)&elementType); + } + + IRTensorViewType* IRBuilder::getTensorViewType(IRType* elementType) + { + return (IRTensorViewType*)getType( + kIROp_TensorViewType, + 1, + (IRInst**)&elementType); + } + + IRTorchTensorType* IRBuilder::getTorchTensorType() + { + return (IRTorchTensorType*)getType(kIROp_TorchTensorType, 0, nullptr); + } + IRDifferentialPairType* IRBuilder::getDifferentialPairType( IRType* valueType, IRInst* witnessTable) @@ -3173,6 +3199,21 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream) + { + IRInst* args[5] = {baseFn, gridDim, blockDim, argsArray, cudaStream}; + return emitIntrinsicInst( + getVoidType(), + kIROp_CudaKernelLaunch, + 5, + args); + } + + IRInst* IRBuilder::emitGetTorchCudaStream() + { + return emitIntrinsicInst(getPtrType(getVoidType()), kIROp_TorchGetCudaStream, 0, nullptr); + } + IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn) { auto inst = createInst<IRBackwardDifferentiatePrimal>( @@ -3659,6 +3700,11 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeTuple, count, args); } + IRInst* IRBuilder::emitMakeTargetTuple(IRType* type, UInt count, IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_MakeTargetTuple, count, args); + } + IRInst* IRBuilder::emitMakeTuple(UInt count, IRInst* const* args) { List<IRType*> types; @@ -3851,6 +3897,11 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args); } + IRInst* IRBuilder::emitMakeArrayList(IRType* type, UInt argCount, IRInst* const* args) + { + return emitIntrinsicInst(type, kIROp_MakeArrayList, argCount, args); + } + IRInst* IRBuilder::emitMakeArrayFromElement( IRType* type, IRInst* element) @@ -3866,6 +3917,12 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeStruct, argCount, args); } + IRInst* IRBuilder::emitMakeTensorView(IRType* type, IRInst* allocator, IRInst* val) + { + IRInst* args[2] = { allocator, val }; + return emitIntrinsicInst(type, kIROp_MakeTensorView, 2, args); + } + IRInst* IRBuilder::emitMakeExistential( IRType* type, IRInst* value, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 025812f83..d74a679d3 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1490,6 +1490,25 @@ struct IRMatrixType : IRType IR_LEAF_ISA(MatrixType) }; +struct IRArrayListType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_LEAF_ISA(ArrayListType) +}; + +struct IRTensorViewType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + + IR_LEAF_ISA(TensorViewType) +}; + +struct IRTorchTensorType : IRType +{ + IR_LEAF_ISA(TorchTensorType) +}; + struct IRSPIRVLiteralType : IRType { IR_LEAF_ISA(SPIRVLiteralType) @@ -1699,6 +1718,12 @@ struct IRTupleType : IRType IR_LEAF_ISA(TupleType) }; +/// Represents a tuple in target language. TargetTupleType will not be lowered to structs. +struct IRTargetTupleType : IRType +{ + IR_LEAF_ISA(TargetTupleType) +}; + /// Represents an `Result<T,E>`, used by functions that throws error codes. struct IRResultType : IRType { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 7144b3450..9d424d1e8 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1150,51 +1150,67 @@ static void addLinkageDecoration( { builder->addExportDecoration(inst, mangledName); } - if (decl->findModifier<PublicModifier>()) + for (auto modifier : decl->modifiers) { - builder->addPublicDecoration(inst); - builder->addKeepAliveDecoration(inst); - } - if (decl->findModifier<HLSLExportModifier>()) - { - builder->addHLSLExportDecoration(inst); - builder->addKeepAliveDecoration(inst); - } - if (decl->findModifier<ExternCppModifier>()) - { - builder->addExternCppDecoration(inst, mangledName); + if (as<PublicModifier>(modifier)) + { + builder->addPublicDecoration(inst); + builder->addKeepAliveDecoration(inst); + } + else if (as<HLSLExportModifier>(modifier)) + { + builder->addHLSLExportDecoration(inst); + builder->addKeepAliveDecoration(inst); + } + else if (as<ExternCppModifier>(modifier)) + { + builder->addExternCppDecoration(inst, mangledName); + } + else if (auto dllImportModifier = as<DllImportAttribute>(modifier)) + { + auto libraryName = dllImportModifier->modulePath; + auto functionName = dllImportModifier->functionName.getLength() + ? dllImportModifier->functionName.getUnownedSlice() + : decl->getName()->text.getUnownedSlice(); + builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName); + } + else if (as<DllExportAttribute>(modifier)) + { + builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); + } + else if (as<CudaDeviceExportAttribute>(modifier)) + { + builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + } + else if (as<CudaHostAttribute>(modifier)) + { + builder->addCudaHostDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + } + else if (as<CudaKernelAttribute>(modifier)) + { + builder->addCudaKernelDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addPublicDecoration(inst); + builder->addKeepAliveDecoration(inst); + } + else if (as<TorchEntryPointAttribute>(modifier)) + { + builder->addTorchEntryPointDecoration(inst, decl->getName()->text.getUnownedSlice()); + builder->addCudaHostDecoration(inst); + builder->addPublicDecoration(inst); + builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); + } } if (as<InterfaceDecl>(decl->parentDecl) && - decl->parentDecl->hasModifier<ComInterfaceAttribute>()) + decl->parentDecl->hasModifier<ComInterfaceAttribute>() && + !inst->findDecoration<IRExternCppDecoration>()) { builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); } - if (auto dllImportModifier = decl->findModifier<DllImportAttribute>()) - { - auto libraryName = dllImportModifier->modulePath; - auto functionName = dllImportModifier->functionName.getLength() - ? dllImportModifier->functionName.getUnownedSlice() - : decl->getName()->text.getUnownedSlice(); - builder->addDllImportDecoration(inst, libraryName.getUnownedSlice(), functionName); - } - if (decl->findModifier<DllExportAttribute>()) - { - builder->addDllExportDecoration(inst, decl->getName()->text.getUnownedSlice()); - builder->addPublicDecoration(inst); - } - if (decl->findModifier<CudaDeviceExportAttribute>()) - { - builder->addCudaDeviceExportDecoration(inst, decl->getName()->text.getUnownedSlice()); - builder->addPublicDecoration(inst); - } - if (decl->findModifier<CudaHostAttribute>()) - { - builder->addCudaHostDecoration(inst); - } - if (decl->findModifier<CudaKernelAttribute>()) - { - builder->addCudaKernelDecoration(inst); - } } static void addLinkageDecoration( diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 714e2c99d..d30c02484 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2193,6 +2193,7 @@ struct OptionsParser if (rawOutputs.getCount() == 0 && rawTargets.getCount() == 1 && (rawTargets[0].format == CodeGenTarget::HostCPPSource || + rawTargets[0].format == CodeGenTarget::PyTorchCppBinding || rawTargets[0].format == CodeGenTarget::CUDASource || ArtifactDescUtil::makeDescForCompileTarget(asExternal(rawTargets[0].format)).kind == ArtifactKind::HostCallable)) { @@ -2258,7 +2259,7 @@ struct OptionsParser case CodeGenTarget::ShaderHostCallable: case CodeGenTarget::HostExecutable: case CodeGenTarget::ShaderSharedLibrary: - + case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::DXIL: rawOutput.isWholeProgram = true; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index cdeb0b259..45f4be477 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2166,9 +2166,9 @@ namespace Slang dispatchExpr->baseFunction = parser->ParseArgExpr(); parser->ReadToken(TokenType::Comma); - dispatchExpr->threadGroupSize = parser->ParseArgExpr(); - parser->ReadToken(TokenType::Comma); dispatchExpr->dispatchSize = parser->ParseArgExpr(); + parser->ReadToken(TokenType::Comma); + dispatchExpr->threadGroupSize = parser->ParseArgExpr(); parser->ReadToken(TokenType::RParent); return dispatchExpr; diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 27aba435f..1c2726551 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -530,6 +530,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt matType->declRef = declRef; return matType; } + else if (magicMod->magicName == "TensorViewType") + { + SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); + auto vecType = astBuilder->getOrCreate<TensorViewType>(ExtractGenericArgType(subst->getArgs()[0])); + vecType->declRef = declRef; + return vecType; + } else if (magicMod->magicName == "Texture") { SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); diff --git a/source/slangc/main.cpp b/source/slangc/main.cpp index 2870a5a3c..2fe9d19a1 100644 --- a/source/slangc/main.cpp +++ b/source/slangc/main.cpp @@ -79,15 +79,15 @@ SLANG_TEST_TOOL_API SlangResult innerMain(StdWriters* stdWriters, slang::IGlobal if (TestToolUtil::hasDeferredStdLib(Index(argc - 1), argv + 1)) { SLANG_RETURN_ON_FAIL(slang_createGlobalSessionWithoutStdLib(SLANG_API_VERSION, session.writeRef())); - TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session); } else if (!session) { // Just create the global session in the regular way if there isn't one set SLANG_RETURN_ON_FAIL(slang_createGlobalSession(SLANG_API_VERSION, session.writeRef())); - TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session); } + TestToolUtil::setSessionDefaultPreludeFromExePath(argv[0], session); + SlangCompileRequest* compileRequest = spCreateCompileRequest(session); compileRequest->addSearchPath(Path::getParentDirectory(Path::getExecutablePath()).getBuffer()); SlangResult res = _compile(compileRequest, argc, argv); |
