From 26a0b3e04689fee1ec9ec071eacd72faf1efe4eb Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 8 Sep 2023 15:57:00 -0700 Subject: Fix attribute highlighting + language server crash. (#3198) * Fix attribute highlighting + language server crash. * Fix wave intrinsic. * Fix. * Fix. --------- Co-authored-by: Yong He --- source/slang/hlsl.meta.slang | 294 ++++++++------------- source/slang/slang-ast-modifier.h | 3 + source/slang/slang-check-expr.cpp | 5 +- source/slang/slang-check-modifier.cpp | 1 + source/slang/slang-language-server-ast-lookup.cpp | 4 +- .../slang-language-server-semantic-tokens.cpp | 3 +- source/slang/slang-language-server.cpp | 1 + source/slang/slang-parser.cpp | 10 +- 8 files changed, 127 insertions(+), 194 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index babd16f6e..c670f234e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5624,7 +5624,7 @@ T WaveMaskProduct(WaveMask mask, T expr) { // TODO: use the correct integer width OpBitcast $$uint %uvalue $expr; - OpGroupNonUniformIMul $$T %mulResult Subgroup 0 %uvalue; + OpGroupNonUniformIMul $$uint %mulResult Subgroup 0 %uvalue; OpBitcast $$T result %mulResult }; } @@ -5687,7 +5687,7 @@ T WaveMaskSum(WaveMask mask, T expr) { // TODO: use the correct integer width OpBitcast $$uint %uvalue $expr; - OpGroupNonUniformIAdd $$T %mulResult Subgroup 0 %uvalue; + OpGroupNonUniformIAdd $$uint %mulResult Subgroup 0 %uvalue; OpBitcast $$T result %mulResult }; } @@ -6120,62 +6120,27 @@ __generic T QuadReadAcrossDiagonal(T localValue); __generic vector QuadReadAcrossDiagonal(vector localValue); __generic matrix QuadReadAcrossDiagonal(matrix localValue); +// WaveActiveBitAnd, WaveActiveBitOr, WaveActiveBitXor +${{{{ +struct WaveActiveBitOpEntry { const char* hlslName; const char* glslName; const char* spirvName; }; +const WaveActiveBitOpEntry kWaveActiveBitOpEntries[] = {{"BitAnd", "And", "BitwiseAnd"}, {"BitOr", "Or", "BitwiseOr"}, {"BitXor", "Xor", "BitwiseXor"}}; +for (auto opName : kWaveActiveBitOpEntries) { +}}}} __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __spirv_capability(GroupNonUniformArithmetic) -T WaveActiveBitAnd(T expr) -{ - __target_switch - { - case glsl: __intrinsic_asm "subgroupAnd($0)"; - case hlsl: __intrinsic_asm "WaveActiveBitAnd"; - case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup Reduce $expr}; - default: - return WaveMaskBitAnd(WaveGetActiveMask(), expr); - } -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) -vector WaveActiveBitAnd(vector expr) -{ - __target_switch - { - case glsl: __intrinsic_asm "subgroupAnd($0)"; - case hlsl: __intrinsic_asm "WaveActiveBitAnd"; - case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector result Subgroup Reduce $expr}; - default: - return WaveMaskBitAnd(WaveGetActiveMask(), expr); - } -} - -__generic -__target_intrinsic(hlsl) -matrix WaveActiveBitAnd(matrix expr) -{ - return WaveMaskBitAnd(WaveGetActiveMask(), expr); -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) -T WaveActiveBitOr(T expr) +T WaveActive$(opName.hlslName)(T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupOr($0)"; - case hlsl: __intrinsic_asm "WaveActiveBitOr"; + case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; + case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseOr $$T result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr}; default: - return WaveMaskBitOr(WaveGetActiveMask(), expr); + return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } @@ -6183,86 +6148,54 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __spirv_capability(GroupNonUniformArithmetic) -vector WaveActiveBitOr(vector expr) +vector WaveActive$(opName.hlslName)(vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupOr($0)"; - case hlsl: __intrinsic_asm "WaveActiveBitOr"; + case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; + case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseOr $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniform$(opName.spirvName) $$vector result Subgroup Reduce $expr}; default: - return WaveMaskBitOr(WaveGetActiveMask(), expr); + return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } __generic __target_intrinsic(hlsl) -matrix WaveActiveBitOr(matrix expr) -{ - return WaveMaskBitOr(WaveGetActiveMask(), expr); -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) -T WaveActiveBitXor(T expr) +matrix WaveActive$(opName.hlslName)(matrix expr) { - __target_switch - { - case glsl: __intrinsic_asm "subgroupXor($0)"; - case hlsl: __intrinsic_asm "WaveActiveBitXor"; - case spirv: - return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup Reduce $expr}; - default: - return WaveMaskBitXor(WaveGetActiveMask(), expr); - } -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) -vector WaveActiveBitXor(vector expr) -{ - __target_switch - { - case glsl: __intrinsic_asm "subgroupXor($0)"; - case hlsl: __intrinsic_asm "WaveActiveBitXor"; - case spirv: - return spirv_asm {OpGroupNonUniformBitwiseXor $$vector result Subgroup Reduce $expr}; - default: - return WaveMaskBitXor(WaveGetActiveMask(), expr); - } + return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } +${{{{ +} // WaveActiveBitAnd, WaveActiveBitOr, WaveActiveBitXor +}}}} -__generic -__target_intrinsic(hlsl) -matrix WaveActiveBitXor(matrix expr) -{ - return WaveMaskBitXor(WaveGetActiveMask(), expr); -} +// WaveActiveMin/Max +${{{{ +const char* kWaveActiveMinMaxNames[] = {"Min", "Max"}; +for (const char* opName : kWaveActiveMinMaxNames) { +}}}} __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __spirv_capability(GroupNonUniformArithmetic) -T WaveActiveMax(T expr) +T WaveActive$(opName)(T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMax($0)"; - case hlsl: __intrinsic_asm "WaveActiveMax"; + case glsl: __intrinsic_asm "subgroup$(opName)($0)"; + case hlsl: __intrinsic_asm "WaveActive$(opName)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMax $$T result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniformF$(opName) $$T result Subgroup Reduce $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMax $$T result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr}; else - return spirv_asm {OpGroupNonUniformSMax $$T result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr}; default: - return WaveMaskMax(WaveGetActiveMask(), expr); + return WaveMask$(opName)(WaveGetActiveMask(), expr); } } @@ -6270,135 +6203,126 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __spirv_capability(GroupNonUniformArithmetic) -vector WaveActiveMax(vector expr) +vector WaveActive$(opName)(vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMax($0)"; - case hlsl: __intrinsic_asm "WaveActiveMax"; + case glsl: __intrinsic_asm "subgroup$(opName)($0)"; + case hlsl: __intrinsic_asm "WaveActive$(opName)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMax $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniformF$(opName) $$vector result Subgroup Reduce $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMax $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniformU$(opName) $$vector result Subgroup Reduce $expr}; else - return spirv_asm {OpGroupNonUniformSMax $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpGroupNonUniformS$(opName) $$vector result Subgroup Reduce $expr}; default: - return WaveMaskMax(WaveGetActiveMask(), expr); + return WaveMask$(opName)(WaveGetActiveMask(), expr); } } __generic __target_intrinsic(hlsl) -matrix WaveActiveMax(matrix expr) +matrix WaveActive$(opName)(matrix expr) { - return WaveMaskMax(WaveGetActiveMask(), expr); + return WaveMask$(opName)(WaveGetActiveMask(), expr); } +${{{{ +} // WaveActiveMinMax. +}}}} + +// WaveActiveProduct/Sum +${{{{ +struct WaveActiveProductSumEntry { const char* hlslName; const char* glslName; }; +const WaveActiveProductSumEntry kWaveActivProductSumNames[] = {{"Product", "Mul"}, {"Sum", "Add"}}; +for (auto opName : kWaveActivProductSumNames) { +}}}} + __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) -T WaveActiveMin(T expr) +T WaveActive$(opName.hlslName)(T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMin($0)"; - case hlsl: __intrinsic_asm "WaveActiveMin"; + case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; + case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMin $$T result Subgroup Reduce $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformF$(opName.glslName) $$T result Subgroup 0 $expr + }; + else if (__isSignedInt()) + { + return spirv_asm + { + OpCapability GroupNonUniformArithmetic; + // TODO: use the correct integer width + OpBitcast $$uint %uvalue $expr; + OpGroupNonUniformI$(opName.glslName) $$uint %mulResult Subgroup 0 %uvalue; + OpBitcast $$T result %mulResult + }; + } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMin $$T result Subgroup Reduce $expr}; - else - return spirv_asm {OpGroupNonUniformSMin $$T result Subgroup Reduce $expr}; + return spirv_asm + { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformI$(opName.glslName) $$T result Subgroup 0 $expr + }; default: - return WaveMaskMin(WaveGetActiveMask(), expr); + return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) -vector WaveActiveMin(vector expr) +__target_intrinsic(hlsl) +vector WaveActive$(opName.hlslName)(vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMin($0)"; - case hlsl: __intrinsic_asm "WaveActiveMin"; + case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; + case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMin $$vector result Subgroup Reduce $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformF$(opName.glslName) $$vector result Subgroup 0 $expr + }; + else if (__isSignedInt()) + { + return spirv_asm + { + OpCapability GroupNonUniformArithmetic; + // TODO: use the correct integer width + OpBitcast $$vector %uvalue $expr; + OpGroupNonUniformI$(opName.glslName) $$vector %$(opName.glslName)Result Subgroup 0 %uvalue; + OpBitcast $$vector result %$(opName.glslName)Result + }; + } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMin $$vector result Subgroup Reduce $expr}; - else - return spirv_asm {OpGroupNonUniformSMin $$vector result Subgroup Reduce $expr}; + return spirv_asm + { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformI$(opName.glslName) $$vector result Subgroup 0 $expr + }; default: - return WaveMaskMin(WaveGetActiveMask(), expr); + return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } __generic __target_intrinsic(hlsl) -matrix WaveActiveMin(matrix expr) +matrix WaveActive$(opName.hlslName)(matrix expr) { - return WaveMaskMin(WaveGetActiveMask(), expr); -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMul($0)") -__target_intrinsic(hlsl) -T WaveActiveProduct(T expr) -{ - return WaveMaskProduct(WaveGetActiveMask(), expr); -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMul($0)") -__target_intrinsic(hlsl) -vector WaveActiveProduct(vector expr) -{ - return WaveMaskProduct(WaveGetActiveMask(), expr); -} - -__generic -__target_intrinsic(hlsl) -matrix WaveActiveProduct(matrix expr) -{ - return WaveMaskProduct(WaveGetActiveMask(), expr); -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAdd($0)") -__target_intrinsic(hlsl) -T WaveActiveSum(T expr) -{ - return WaveMaskSum(WaveGetActiveMask(), expr); -} - -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAdd($0)") -__target_intrinsic(hlsl) -vector WaveActiveSum(vector expr) -{ - return WaveMaskSum(WaveGetActiveMask(), expr); -} - -__generic -__target_intrinsic(hlsl) -matrix WaveActiveSum(matrix expr) -{ - return WaveMaskSum(WaveGetActiveMask(), expr); + return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } +${{{{ +} // WaveActiveProduct/WaveActiveProductSum. +}}}} __generic __glsl_extension(GL_KHR_shader_subgroup_vote) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index b890343fc..3bd52245a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -601,6 +601,9 @@ class AttributeBase : public Modifier AttributeDecl* attributeDecl = nullptr; + // The original identifier token representing the last part of the qualified name. + Token originalIdentifierToken; + List args; }; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 88d95f04e..75fd5177e 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2073,7 +2073,7 @@ namespace Slang // Replace the expression. This should make this situation easier to detect. expr->arguments[pp] = lValueImplicitCast; } - else + else if (!as(argExpr->type)) { getSink()->diagnose( argExpr, @@ -2102,11 +2102,10 @@ namespace Slang // Fall back, in case there are other reasons... diagnostic = &Diagnostics::implicitCastUsedAsLValue; } - getSink()->diagnoseWithoutSourceView( argExpr, *diagnostic, - implicitCastExpr->arguments[pp]->type, + implicitCastExpr->arguments[0]->type, implicitCastExpr->type); } diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 21d7669ce..53283fbe1 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -810,6 +810,7 @@ namespace Slang // First copy all of the state over from the original attribute. attr->keywordName = uncheckedAttr->keywordName; + attr->originalIdentifierToken = uncheckedAttr->originalIdentifierToken; attr->args = uncheckedAttr->args; attr->loc = uncheckedAttr->loc; attr->attributeDecl = attrDecl; diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 810315dd9..6c00312ba 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -684,8 +684,8 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) if (attribute->getKeywordName() && _isLocInRange( &context, - attribute->getKeywordNameAndLoc().loc, - attribute->getKeywordName()->text.getLength())) + attribute->originalIdentifierToken.loc, + attribute->originalIdentifierToken.getContentLength())) { ASTLookupResult result; result.path = context.nodePath; diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp index ab6d8b5ab..e85da9824 100644 --- a/source/slang/slang-language-server-semantic-tokens.cpp +++ b/source/slang/slang-language-server-semantic-tokens.cpp @@ -199,7 +199,8 @@ List getSemanticTokens(Linkage* linkage, Module* module, UnownedS if (attr->getKeywordName()) { SemanticToken token = _createSemanticToken( - manager, attr->getKeywordNameAndLoc().loc, attr->getKeywordName()); + manager, attr->originalIdentifierToken.loc, nullptr); + token.length = (int)attr->originalIdentifierToken.getContentLength(); token.type = SemanticTokenType::Type; maybeInsertToken(token); } diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 5574b995d..54da4120c 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -714,6 +714,7 @@ SlangResult LanguageServer::hover( else if (auto attr = as(leafNode)) { fillDeclRefHoverInfo(makeDeclRef(attr->attributeDecl)); + hover.range.end.character = hover.range.start.character + (int)attr->originalIdentifierToken.getContentLength(); } if (sb.getLength() == 0) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 306d2cbec..e184585e3 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -878,7 +878,7 @@ namespace Slang // // '::'? identifier ('::' identifier)* - static Token parseAttributeName(Parser* parser) + static Token parseAttributeName(Parser* parser, Token& outOriginalLastToken) { const SourceLoc scopedIdSourceLoc = parser->tokenReader.peekLoc(); @@ -892,6 +892,7 @@ namespace Slang return parser->ReadToken(); const Token firstIdentifier = parser->ReadToken(TokenType::Identifier); + outOriginalLastToken = firstIdentifier; if (initialTokenType != TokenType::Scope && parser->tokenReader.peekTokenType() != TokenType::Scope) { return firstIdentifier; @@ -911,6 +912,7 @@ namespace Slang scopedIdentifierBuilder.append('_'); const Token nextIdentifier(parser->ReadToken(TokenType::Identifier)); + outOriginalLastToken = nextIdentifier; scopedIdentifierBuilder.append(nextIdentifier.getContent()); } @@ -946,12 +948,14 @@ namespace Slang // seems better to not complicate the parsing process any more. // - Token nameToken = parseAttributeName(parser); + Token originalLastToken; + Token nameToken = parseAttributeName(parser, originalLastToken); UncheckedAttribute* modifier = parser->astBuilder->create(); modifier->keywordName = nameToken.getName(); - modifier->loc = nameToken.getLoc(); + modifier->loc = originalLastToken.getLoc(); modifier->scope = parser->currentScope; + modifier->originalIdentifierToken = originalLastToken; if (AdvanceIf(parser, TokenType::LParent)) { -- cgit v1.2.3