diff options
| author | Yong He <yonghe@outlook.com> | 2024-04-17 21:32:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-17 21:32:28 -0700 |
| commit | 2c66cc7ef03b4d38fc463f2c8609a81232fcb91a (patch) | |
| tree | 7e100ddd0df91e8d7ae90c3335bb416bc50ad6ac | |
| parent | 4b3f554a58e4224806c31d66874fbe60f1f09332 (diff) | |
Add skeleton for metal backend. (#3971)
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 6 | ||||
| -rw-r--r-- | slang.h | 3 | ||||
| -rw-r--r-- | source/compiler-core/slang-artifact-desc-util.cpp | 1 | ||||
| -rw-r--r-- | source/core/slang-type-text-util.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-capabilities.capdef | 1 | ||||
| -rw-r--r-- | source/slang/slang-compiler.cpp | 7 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-doc-markdown-writer.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 661 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.h | 81 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-profile.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 4 | ||||
| -rw-r--r-- | tests/metal/simple-compute.slang | 10 | ||||
| -rw-r--r-- | tools/slang-test/slang-test-main.cpp | 1 |
21 files changed, 809 insertions, 3 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 45f7abb80..f2d470d93 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -342,6 +342,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-emit-cuda.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-glsl.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-hlsl.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-emit-metal.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-precedence.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-source-writer.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-spirv-ops-debug-info-ext.h" />
@@ -575,6 +576,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-emit-cuda.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-glsl.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-hlsl.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-emit-metal.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-precedence.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-source-writer.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-spirv.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 9403d3649..ab06900ba 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -114,6 +114,9 @@ <ClInclude Include="..\..\..\source\slang\slang-emit-hlsl.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-emit-metal.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-emit-precedence.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -809,6 +812,9 @@ <ClCompile Include="..\..\..\source\slang\slang-emit-hlsl.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-emit-metal.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-emit-precedence.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -609,6 +609,7 @@ extern "C" SLANG_HOST_CPP_SOURCE, ///< C++ code for host library or executable. SLANG_HOST_HOST_CALLABLE, ///< Host callable host code (ie non kernel/shader) SLANG_CPP_PYTORCH_BINDING, ///< C++ PyTorch binding code. + SLANG_METAL, ///< Metal shading language SLANG_TARGET_COUNT_OF, }; @@ -641,6 +642,7 @@ extern "C" SLANG_PASS_THROUGH_NVRTC, ///< NVRTC Cuda compiler SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' - includes LLVM and Clang SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt + SLANG_PASS_THROUGH_METAL, ///< Metal compiler SLANG_PASS_THROUGH_COUNT_OF, }; @@ -743,6 +745,7 @@ extern "C" SLANG_SOURCE_LANGUAGE_CPP, SLANG_SOURCE_LANGUAGE_CUDA, SLANG_SOURCE_LANGUAGE_SPIRV, + SLANG_SOURCE_LANGUAGE_METAL, SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index 559a1b4ca..2646555bc 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -286,6 +286,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case SLANG_PTX: return Desc::make(Kind::Executable, Payload::PTX, Style::Kernel, 0); case SLANG_OBJECT_CODE: return Desc::make(Kind::ObjectCode, Payload::HostCPU, Style::Kernel, 0); case SLANG_HOST_HOST_CALLABLE: return Desc::make(Kind::HostCallable, Payload::HostCPU, Style::Host, 0); + case SLANG_METAL: return Desc::make(Kind::Source, Payload::Metal, Style::Kernel, 0); default: break; } diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index 676be3976..7e1ffc439 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -61,6 +61,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] = { SLANG_SHADER_HOST_CALLABLE, "", "host-callable,callable", "Host callable" }, { SLANG_OBJECT_CODE, "obj,o", "object-code", "Object code" }, { SLANG_HOST_HOST_CALLABLE, "", "host-host-callable", "Host callable for host execution" }, + { SLANG_METAL, "metal", "metal", "Metal shader source"}, }; static const NamesDescriptionValue s_languageInfos[] = diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index b83585c5c..eb546ae6b 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -45,6 +45,7 @@ def glsl : target + textualTarget; def c : target + textualTarget; def cpp : target + textualTarget; def cuda : target + textualTarget; +def metal : target + textualTarget; // We have multiple capabilities for the various SPIR-V versions, // arranged so that they inherit from one another to represent which versions diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 3f5351ff4..04536c81a 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -480,7 +480,10 @@ namespace Slang { return SourceLanguage::SPIRV; } - + case PassThroughMode::MetalC: + { + return SourceLanguage::Metal; + } default: break; } SLANG_ASSERT(!"Unknown compiler"); @@ -499,6 +502,7 @@ namespace Slang case CodeGenTarget::HostCPPSource: case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: + case CodeGenTarget::Metal: { return PassThroughMode::None; } @@ -1617,6 +1621,7 @@ namespace Slang case CodeGenTarget::HostCPPSource: case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: + case CodeGenTarget::Metal: { RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 014b678f5..7f7903f8b 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -92,6 +92,7 @@ namespace Slang CUDAObjectCode = SLANG_CUDA_OBJECT_CODE, ObjectCode = SLANG_OBJECT_CODE, HostHostCallable = SLANG_HOST_HOST_CALLABLE, + Metal = SLANG_METAL, CountOf = SLANG_TARGET_COUNT_OF, }; @@ -1236,6 +1237,7 @@ namespace Slang NVRTC = SLANG_PASS_THROUGH_NVRTC, ///< NVRTC CUDA compiler LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt + MetalC = SLANG_PASS_THROUGH_METAL, CountOf = SLANG_PASS_THROUGH_COUNT_OF, }; void printDiagnosticArg(StringBuilder& sb, PassThroughMode val); diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index 77b239437..ac3b9ca7e 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -479,7 +479,10 @@ static DocMarkdownWriter::Requirement _getRequirementFromTargetToken(const Token { return Requirement{ CodeGenTarget::CSource, targetName }; } - + else if (isCapabilityDerivedFrom(targetCap, CapabilityAtom::metal)) + { + return Requirement{ CodeGenTarget::Metal, targetName }; + } return Requirement{ CodeGenTarget::Unknown, String() }; } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index ceb4e6de0..7cb4871be 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -91,6 +91,10 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext { return SourceLanguage::CUDA; } + case CodeGenTarget::Metal: + { + return SourceLanguage::Metal; + } } } @@ -3699,7 +3703,6 @@ void CLikeSourceEmitter::emitVarModifiers(IRVarLayout* layout, IRInst* varDecl, { // TODO(JS): We could push all of this onto the target impls, and then not need so many virtual hooks. emitVarDecorationsImpl(varDecl); - emitTempModifiers(varDecl); if (!layout) diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp new file mode 100644 index 000000000..1ce25a2da --- /dev/null +++ b/source/slang/slang-emit-metal.cpp @@ -0,0 +1,661 @@ +// slang-emit-metal.cpp +#include "slang-emit-metal.h" + +#include "../core/slang-writer.h" + +#include "slang-ir-util.h" +#include "slang-emit-source-writer.h" +#include "slang-mangled-lexer.h" + +#include <assert.h> + +namespace Slang { + +void MetalSourceEmitter::_emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val) +{ + SLANG_UNUSED(entryPoint); + assert(val); + + m_writer->emit("[["); + m_writer->emit(name); + m_writer->emit("(\""); + m_writer->emit(val->getStringSlice()); + m_writer->emit("\")]]\n"); +} + +void MetalSourceEmitter::_emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val) +{ + SLANG_UNUSED(entryPoint); + SLANG_ASSERT(val); + + auto intVal = getIntVal(val); + + m_writer->emit("[["); + m_writer->emit(name); + m_writer->emit("("); + m_writer->emit(intVal); + m_writer->emit(")]]\n"); +} + +void MetalSourceEmitter::_emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling) +{ + // Metal does not use explicit binding. + SLANG_UNUSED(kind); + SLANG_UNUSED(chain); + SLANG_UNUSED(inst); + SLANG_UNUSED(uniformSemanticSpelling); +} + +void MetalSourceEmitter::_emitHLSLRegisterSemantics(EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling) +{ + // TODO: implement. + SLANG_UNUSED(chain); + SLANG_UNUSED(inst); + SLANG_UNUSED(uniformSemanticSpelling); +} + +void MetalSourceEmitter::_emitHLSLRegisterSemantics(IRVarLayout* varLayout, IRInst* inst, char const* uniformSemanticSpelling) +{ + // TODO: implement. + SLANG_UNUSED(varLayout); + SLANG_UNUSED(inst); + SLANG_UNUSED(uniformSemanticSpelling); +} + +void MetalSourceEmitter::_emitHLSLParameterGroupFieldLayoutSemantics(EmitVarChain* chain) +{ + // TODO: implement. + SLANG_UNUSED(chain); +} + +void MetalSourceEmitter::_emitHLSLParameterGroupFieldLayoutSemantics(IRVarLayout* fieldLayout, EmitVarChain* inChain) +{ + // TODO: implement. + SLANG_UNUSED(fieldLayout); + SLANG_UNUSED(inChain); +} + +void MetalSourceEmitter::_emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) +{ + // Metal does not allow shader parameters declared as global variables, so we shouldn't see this. + SLANG_UNUSED(varDecl); + SLANG_UNUSED(type); + SLANG_ASSERT(!"Metal does not allow shader parameters declared as global variables."); +} + +void MetalSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType) +{ + if (getIntVal(texType->getIsShadowInst()) != 0) + { + m_writer->emit("depth"); + } + else + { + m_writer->emit("texture"); + } + + switch (texType->GetBaseShape()) + { + case SLANG_TEXTURE_1D: m_writer->emit("1d"); break; + case SLANG_TEXTURE_2D: m_writer->emit("2d"); break; + case SLANG_TEXTURE_3D: m_writer->emit("3d"); break; + case SLANG_TEXTURE_CUBE: m_writer->emit("cube"); break; + case SLANG_TEXTURE_BUFFER: m_writer->emit("1d"); break; + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled resource shape"); + break; + } + + if (texType->isMultisample()) + { + m_writer->emit("_ms"); + } + if (texType->isArray()) + { + m_writer->emit("_array"); + } + m_writer->emit("<"); + emitType(getVectorElementType(texType->getElementType())); + m_writer->emit(", "); + + switch (texType->getAccess()) + { + case SLANG_RESOURCE_ACCESS_READ: + m_writer->emit("access::sample"); + break; + + case SLANG_RESOURCE_ACCESS_READ_WRITE: + case SLANG_RESOURCE_ACCESS_APPEND: + case SLANG_RESOURCE_ACCESS_CONSUME: + case SLANG_RESOURCE_ACCESS_FEEDBACK: + case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: + m_writer->emit("access::read_write"); + break; + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled resource access mode"); + break; + } + + m_writer->emit(">"); +} + +void MetalSourceEmitter::_emitHLSLSubpassInputType(IRSubpassInputType* subpassType) +{ + SLANG_UNUSED(subpassType); +} + +void MetalSourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling) +{ + auto layout = getVarLayout(inst); + if (layout) + { + _emitHLSLRegisterSemantics(layout, inst, uniformSemanticSpelling); + } +} + +void MetalSourceEmitter::emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) +{ + _emitHLSLParameterGroup(varDecl, type); +} + +void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) +{ + auto profile = m_effectiveProfile; + auto stage = entryPointDecor->getProfile().getStage(); + + switch (stage) + { + case Stage::Fragment: + m_writer->emit("[[fragment]] "); + break; + case Stage::Vertex: + m_writer->emit("[[vertex]] "); + break; + case Stage::Compute: + m_writer->emit("[[kernel]] "); + break; + default: + SLANG_ABORT_COMPILATION("unsupported stage."); + } + + switch (stage) + { + case Stage::Pixel: + { + if (irFunc->findDecoration<IREarlyDepthStencilDecoration>()) + { + m_writer->emit("[[early_fragment_tests]]\n"); + } + break; + } + default: + break; + } +} + +bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + switch (inst->getOp()) + { + case kIROp_MakeVector: + case kIROp_MakeMatrix: + { + if (inst->getOperandCount() == 1) + { + EmitOpInfo outerPrec = inOuterPrec; + bool needClose = false; + + auto prec = getInfo(EmitOp::Prefix); + needClose = maybeEmitParens(outerPrec, prec); + + // Need to emit as cast for HLSL + emitType(inst->getDataType()); + m_writer->emit("("); + emitOperand(inst->getOperand(0), rightSide(outerPrec, prec)); + m_writer->emit(") "); + + maybeCloseParens(needClose); + // Handled + return true; + } + break; + } + case kIROp_BitCast: + { + auto toType = inst->getDataType(); + + m_writer->emit("as_type<"); + emitType(toType); + m_writer->emit(">("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_StringLit: + { + const auto handler = StringEscapeUtil::getHandler(StringEscapeUtil::Style::Slang); + + StringBuilder buf; + const UnownedStringSlice slice = as<IRStringLit>(inst)->getStringSlice(); + StringEscapeUtil::appendQuoted(handler, slice, buf); + + m_writer->emit(buf); + + return true; + } + case kIROp_ByteAddressBufferLoad: + { + // This only works for loads of 4-byte values. + // Other element types should have been lowered by previous legalization passes. + auto elementType = inst->getDataType(); + auto buffer = inst->getOperand(0); + auto offset = inst->getOperand(1); + m_writer->emit("as_type<"); + emitType(elementType); + m_writer->emit(">("); + emitOperand(buffer, getInfo(EmitOp::General)); + m_writer->emit("[("); + emitOperand(offset, getInfo(EmitOp::General)); + m_writer->emit(")>>2)]"); + return true; + } + case kIROp_ByteAddressBufferStore: + { + // This only works for loads of 4-byte values. + // Other element types should have been lowered by previous legalization passes. + auto buffer = inst->getOperand(0); + auto offset = inst->getOperand(1); + emitOperand(buffer, getInfo(EmitOp::General)); + m_writer->emit("[("); + emitOperand(offset, getInfo(EmitOp::General)); + m_writer->emit(")>>2)] = as_type<uint32_t>("); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + + default: break; + } + // Not handled + return false; +} + +void MetalSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) +{ + // In some cases we *need* to use the built-in syntax sugar for vector types, + // so we will try to emit those whenever possible. + // + if( elementCount >= 1 && elementCount <= 4 ) + { + switch( elementType->getOp() ) + { + case kIROp_FloatType: + case kIROp_IntType: + case kIROp_UIntType: + // TODO: There are more types that need to be covered here + emitType(elementType); + m_writer->emit(elementCount); + return; + + default: + break; + } + } + + // As a fallback, we will use the `vector<...>` type constructor, + // although we should not expect to run into types that don't + // have a sugared form. + // + m_writer->emit("vector<"); + emitType(elementType); + m_writer->emit(","); + m_writer->emit(elementCount); + m_writer->emit(">"); +} + +void MetalSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) +{ + switch (decl->getMode()) + { + case kIRLoopControl_Unroll: + m_writer->emit("[unroll]\n"); + break; + case kIRLoopControl_Loop: + m_writer->emit("[loop]\n"); + break; + default: + break; + } +} + +static bool _canEmitExport(const Profile& profile) +{ + const auto family = profile.getFamily(); + const auto version = profile.getVersion(); + // Is ita late enough version of shader model to output with 'export' + return (family == ProfileFamily::DX && version >= ProfileVersion::DX_6_1); +} + +/* virtual */void MetalSourceEmitter::emitFuncDecorationsImpl(IRFunc* func) +{ + // Specially handle export, as we don't want to emit it multiple times + if (getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram) && + _canEmitExport(m_effectiveProfile)) + { + for (auto decoration : func->getDecorations()) + { + const auto op = decoration->getOp(); + if (op == kIROp_PublicDecoration || + op == kIROp_HLSLExportDecoration) + { + m_writer->emit("export\n"); + break; + } + } + } + + // Use the default for others + Super::emitFuncDecorationsImpl(func); +} + +void MetalSourceEmitter::emitIfDecorationsImpl(IRIfElse* ifInst) +{ + // Does not apply to metal. + SLANG_UNUSED(ifInst); +} + +void MetalSourceEmitter::emitSwitchDecorationsImpl(IRSwitch* switchInst) +{ + // Does not apply to metal. + SLANG_UNUSED(switchInst); +} + +void MetalSourceEmitter::emitFuncDecorationImpl(IRDecoration* decoration) +{ + // Does not apply to metal. + SLANG_UNUSED(decoration); +} + +void MetalSourceEmitter::emitSimpleValueImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_FloatLit: + { + IRConstant* constantInst = static_cast<IRConstant*>(inst); + IRConstant::FloatKind kind = constantInst->getFloatKind(); + switch (kind) + { + case IRConstant::FloatKind::Nan: + { + m_writer->emit("(0.0 / 0.0)"); + return; + } + case IRConstant::FloatKind::PositiveInfinity: + { + m_writer->emit("(1.0 / 0.0)"); + return; + } + case IRConstant::FloatKind::NegativeInfinity: + { + m_writer->emit("(-1.0 / 0.0)"); + return; + } + default: break; + } + break; + } + + default: break; + } + + Super::emitSimpleValueImpl(inst); +} + +void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) +{ + switch (type->getOp()) + { + case kIROp_VoidType: + case kIROp_BoolType: + case kIROp_Int8Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_Int16Type: + case kIROp_UInt16Type: + case kIROp_HalfType: + { + m_writer->emit(getDefaultBuiltinTypeName(type->getOp())); + return; + } + case kIROp_IntPtrType: + m_writer->emit("int64_t"); + return; + case kIROp_UIntPtrType: + m_writer->emit("uint64_t"); + return; + case kIROp_StructType: + m_writer->emit(getName(type)); + return; + + case kIROp_VectorType: + { + auto vecType = (IRVectorType*)type; + emitVectorTypeNameImpl(vecType->getElementType(), getIntVal(vecType->getElementCount())); + return; + } + case kIROp_MatrixType: + { + auto matType = (IRMatrixType*)type; + + // Similar to GLSL, Metal's column-major is really our row-major. + m_writer->emit("matrix<"); + emitType(matType->getElementType()); + m_writer->emit(","); + emitVal(matType->getColumnCount(), getInfo(EmitOp::General)); + m_writer->emit(","); + emitVal(matType->getRowCount(), getInfo(EmitOp::General)); + m_writer->emit("> "); + return; + } + case kIROp_SamplerStateType: + case kIROp_SamplerComparisonStateType: + { + m_writer->emit("sampler"); + return; + } + case kIROp_NativeStringType: + case kIROp_StringType: + { + m_writer->emit("int"); + return; + } + case kIROp_ParameterBlockType: + case kIROp_ConstantBufferType: + { + m_writer->emit("constant "); + emitType((IRType*)type->getOperand(0)); + m_writer->emit("*"); + return; + } + default: break; + } + + if (auto texType = as<IRTextureType>(type)) + { + _emitHLSLTextureType(texType); + return; + } + else if (auto imageType = as<IRGLSLImageType>(type)) + { + _emitHLSLTextureType(imageType); + return; + } + else if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type)) + { + m_writer->emit("device "); + emitType(structuredBufferType->getElementType()); + m_writer->emit("*"); + return; + } + else if (const auto untypedBufferType = as<IRUntypedBufferResourceType>(type)) + { + switch (type->getOp()) + { + case kIROp_HLSLByteAddressBufferType: + case kIROp_HLSLRWByteAddressBufferType: + case kIROp_HLSLRasterizerOrderedByteAddressBufferType: + m_writer->emit("device "); + m_writer->emit("uint32_t *"); + break; + case kIROp_RaytracingAccelerationStructureType: m_writer->emit("acceleration_structure<instancing>"); break; + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type"); + break; + } + return; + } + else if(auto specializedType = as<IRSpecialize>(type)) + { + // If a `specialize` instruction made it this far, then + // it represents an intrinsic generic type. + // + emitSimpleType((IRType*) getSpecializedValue(specializedType)); + m_writer->emit("<"); + UInt argCount = specializedType->getArgCount(); + for (UInt ii = 0; ii < argCount; ++ii) + { + if (ii != 0) m_writer->emit(", "); + emitVal(specializedType->getArg(ii), getInfo(EmitOp::General)); + } + m_writer->emit(" >"); + return; + } + + // HACK: As a fallback for HLSL targets, assume that the name of the + // instruction being used is the same as the name of the HLSL type. + { + auto opInfo = getIROpInfo(type->getOp()); + m_writer->emit(opInfo.name); + UInt operandCount = type->getOperandCount(); + if (operandCount) + { + m_writer->emit("<"); + for (UInt ii = 0; ii < operandCount; ++ii) + { + if (ii != 0) m_writer->emit(", "); + emitVal(type->getOperand(ii), getInfo(EmitOp::General)); + } + m_writer->emit(" >"); + } + } +} + +void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace) +{ + if (as<IRGroupSharedRate>(rate)) + { + m_writer->emit("threadgroup "); + } +} + +void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets) +{ + // Metal does not use semantics. + SLANG_UNUSED(inst); + SLANG_UNUSED(allowOffsets); +} + +void MetalSourceEmitter::_emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name) +{ + SLANG_UNUSED(decoration); + SLANG_UNUSED(name); +} + +void MetalSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) +{ + Super::emitSimpleFuncParamImpl(param); +} + +static UnownedStringSlice _getInterpolationModifierText(IRInterpolationMode mode) +{ + switch (mode) + { + case IRInterpolationMode::PerVertex: + case IRInterpolationMode::NoInterpolation: return UnownedStringSlice::fromLiteral("[[flat]]"); + case IRInterpolationMode::NoPerspective: return UnownedStringSlice::fromLiteral("[[center_no_perspective]]"); + case IRInterpolationMode::Linear: return UnownedStringSlice::fromLiteral("[[sample_no_perspective]]"); + case IRInterpolationMode::Sample: return UnownedStringSlice::fromLiteral("[[sample_perspective]]"); + case IRInterpolationMode::Centroid: return UnownedStringSlice::fromLiteral("[[center_perspective]]"); + default: return UnownedStringSlice(); + } +} + +void MetalSourceEmitter::emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) +{ + SLANG_UNUSED(layout); + SLANG_UNUSED(valueType); + + for (auto dd : varInst->getDecorations()) + { + if (dd->getOp() != kIROp_InterpolationModeDecoration) + continue; + + auto decoration = (IRInterpolationModeDecoration*)dd; + + UnownedStringSlice modeText = _getInterpolationModifierText(decoration->getMode()); + if (modeText.getLength() > 0) + { + m_writer->emit(modeText); + m_writer->emitChar(' '); + } + } +} + +void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* layout) +{ + SLANG_UNUSED(varInst); + SLANG_UNUSED(valueType); + SLANG_UNUSED(layout); + // We emit packoffset as a semantic in `emitSemantic`, so nothing to do here. +} + +void MetalSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst) +{ + SLANG_UNUSED(varInst); +} + +void MetalSourceEmitter::emitVarDecorationsImpl(IRInst* varInst) +{ + SLANG_UNUSED(varInst); +} + +void MetalSourceEmitter::emitMatrixLayoutModifiersImpl(IRVarLayout*) +{ + // Metal only supports column major layout, and we must have + // already translated all matrix ops to assume column-major + // at this stage. +} + +void MetalSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) +{ + SLANG_UNUSED(inst); +} + +void MetalSourceEmitter::emitFrontMatterImpl(TargetRequest*) +{ + +} + +void MetalSourceEmitter::emitGlobalInstImpl(IRInst* inst) +{ + Super::emitGlobalInstImpl(inst); +} + +} // namespace Slang diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h new file mode 100644 index 000000000..4c4f27be3 --- /dev/null +++ b/source/slang/slang-emit-metal.h @@ -0,0 +1,81 @@ +// slang-emit-metal.h +#ifndef SLANG_EMIT_METAL_H +#define SLANG_EMIT_METAL_H + +#include "slang-emit-c-like.h" + +namespace Slang +{ +class MetalExtensionTracker : public ExtensionTracker {}; + +class MetalSourceEmitter : public CLikeSourceEmitter +{ +public: + typedef CLikeSourceEmitter Super; + + MetalSourceEmitter(const Desc& desc) + : Super(desc) + , m_extensionTracker(new MetalExtensionTracker()) + {} + + virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; } + +protected: + RefPtr<MetalExtensionTracker> m_extensionTracker; + + virtual void emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling) SLANG_OVERRIDE; + virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE; + virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; + + virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; + + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; + virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE; + virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; + virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; + virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; + + virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; + virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; + virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; + virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; + virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; + + virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; + virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; + virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE; + virtual void emitFuncDecorationsImpl(IRFunc* func) SLANG_OVERRIDE; + + virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) SLANG_OVERRIDE; + virtual void emitIfDecorationsImpl(IRIfElse* ifInst) SLANG_OVERRIDE; + + virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; + + virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; + + // Emit a single `register` semantic, as appropriate for a given resource-type-specific layout info + // Keyword to use in the uniform case (`register` for globals, `packoffset` inside a `cbuffer`) + void _emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling = "register"); + + // Emit all the `register` semantics that are appropriate for a particular variable layout + void _emitHLSLRegisterSemantics(EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling = "register"); + void _emitHLSLRegisterSemantics(IRVarLayout* varLayout, IRInst* inst, char const* uniformSemanticSpelling = "register"); + + void _emitHLSLParameterGroupFieldLayoutSemantics(EmitVarChain* chain); + void _emitHLSLParameterGroupFieldLayoutSemantics(IRVarLayout* fieldLayout, EmitVarChain* inChain); + + void _emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type); + + void _emitHLSLTextureType(IRTextureTypeBase* texType); + + void _emitHLSLSubpassInputType(IRSubpassInputType* subpassType); + + void _emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val); + void _emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val); + + void _emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name); +}; + +} +#endif diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index fbc99b0ce..633d17345 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -94,6 +94,7 @@ #include "slang-emit-glsl.h" #include "slang-emit-hlsl.h" +#include "slang-emit-metal.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" #include "slang-emit-torch.h" @@ -1171,6 +1172,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr sourceEmitter = new CUDASourceEmitter(desc); break; } + case SourceLanguage::Metal: + { + sourceEmitter = new MetalSourceEmitter(desc); + break; + } default: break; } break; diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index defca19fd..86e255705 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1467,6 +1467,7 @@ static bool doesTargetAllowUnresolvedFuncSymbol(TargetRequest* req) switch (req->getTarget()) { case CodeGenTarget::HLSL: + case CodeGenTarget::Metal: case CodeGenTarget::DXIL: case CodeGenTarget::DXILAssembly: case CodeGenTarget::HostCPPSource: diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 380d1141e..eaf254190 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -16,6 +16,13 @@ bool isPointerOfType(IRInst* type, IROp opCode) return false; } +IRType* getVectorElementType(IRType* type) +{ + if (auto vectorType = as<IRVectorType>(type)) + return vectorType->getElementType(); + return type; +} + Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* interfaceType) { Dictionary<IRInst*, IRInst*> result; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 40ba783b9..94ae4bc9f 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -75,6 +75,8 @@ Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* inte bool isComInterfaceType(IRType* type); +// If `type` is a vector, returns its element type. Otherwise, return `type`. +IRType* getVectorElementType(IRType* type); IROp getTypeStyle(IROp op); IROp getTypeStyle(BaseType op); diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index a1c08fe6a..bd6feab23 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -18,6 +18,7 @@ namespace Slang CPP = SLANG_SOURCE_LANGUAGE_CPP, CUDA = SLANG_SOURCE_LANGUAGE_CUDA, SPIRV = SLANG_SOURCE_LANGUAGE_SPIRV, + Metal = SLANG_SOURCE_LANGUAGE_METAL, CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 30c41d0fe..e79c51256 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1537,6 +1537,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe case CodeGenTarget::ShaderSharedLibrary: case CodeGenTarget::CPPSource: case CodeGenTarget::CSource: + case CodeGenTarget::Metal: { // For now lets use some fairly simple CPU binding rules @@ -1788,6 +1789,10 @@ SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* targetProgr // Currently DXBytecode and DXIL are generated via HLSL return SourceLanguage::HLSL; } + case CodeGenTarget::Metal: + { + return SourceLanguage::Metal; + } case CodeGenTarget::CSource: { return SourceLanguage::C; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 9d018d47f..437f0cd63 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1709,6 +1709,10 @@ CapabilitySet TargetRequest::getTargetCaps() atoms.add(CapabilityName::cuda); break; + case CodeGenTarget::Metal: + atoms.add(CapabilityName::metal); + break; + default: break; } diff --git a/tests/metal/simple-compute.slang b/tests/metal/simple-compute.slang new file mode 100644 index 000000000..e099704be --- /dev/null +++ b/tests/metal/simple-compute.slang @@ -0,0 +1,10 @@ +//TEST:SIMPLE(filecheck=CHECK): -target metal + +RWStructuredBuffer<float> outputBuffer; + +// CHECK: {{.*}}kernel{{.*}} void main() +[numthreads(1,1,1)] +void main() +{ + outputBuffer[0] = 1.0f; +}
\ No newline at end of file diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index c93bebf33..cb77c8830 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -967,6 +967,7 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target) case SLANG_CPP_PYTORCH_BINDING: case SLANG_HOST_CPP_SOURCE: case SLANG_CUDA_SOURCE: + case SLANG_METAL: { return 0; } |
