From 17c7163c2ae8fc290e70b43d8700b68ef18b1ee1 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 6 Oct 2023 14:03:18 -0700 Subject: Small type system fixes. (#3265) --- source/slang/slang-check-expr.cpp | 24 +++++++++++++++++++++-- source/slang/slang-ir-check-differentiability.cpp | 5 +++++ source/slang/slang-ir-util.h | 12 +++++++++--- source/slang/slang-ir.cpp | 12 +++++++----- source/slang/slang-language-server-completion.cpp | 2 +- source/slang/slang-language-server.cpp | 22 +++++++++++++++++---- source/slang/slang-language-server.h | 1 + 7 files changed, 63 insertions(+), 15 deletions(-) (limited to 'source/slang') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 05a6ed249..22bc2cae8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -372,7 +372,24 @@ namespace Slang { expr->type.isLeftValue = false; } - + else + { + // If we are accessing a readonly property, then the result + // is not an l-value. + if (auto propertyDecl = as(declRef.getDecl())) + { + bool isLValue = false; + for (auto member : propertyDecl->members) + { + if (as(member) || as< RefAccessorDecl>(member)) + { + isLValue = true; + break; + } + } + expr->type.isLeftValue = isLValue; + } + } return expr; } } @@ -3322,7 +3339,10 @@ namespace Slang // A swizzle can be used as an l-value as long as there // were no duplicates in the list of components - swizExpr->type.isLeftValue = !anyDuplicates; + swizExpr->type.isLeftValue = !anyDuplicates && + swizExpr->base && + swizExpr->base->type && + swizExpr->base->type.isLeftValue; return swizExpr; } diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index b937fe052..ddb70d779 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -85,7 +85,12 @@ public: switch (func->getOp()) { case kIROp_ForwardDifferentiate: + if (auto fwdDerivative = func->getOperand(0)->findDecoration()) + return isDifferentiableFunc(fwdDerivative->getForwardDerivativeFunc(), level); + return isDifferentiableFunc(func->getOperand(0), level); case kIROp_BackwardDifferentiate: + if (auto bwdDerivative = func->getOperand(0)->findDecoration()) + return isDifferentiableFunc(bwdDerivative->getBackwardDerivativeFunc(), level); return isDifferentiableFunc(func->getOperand(0), level); default: break; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index ff6298f39..0b377a3d1 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -152,9 +152,15 @@ inline bool isGenericParam(IRInst* param) inline IRInst* unwrapAttributedType(IRInst* type) { - while (auto attrType = as(type)) - type = attrType->getBaseType(); - return type; + for (;;) + { + if (auto attrType = as(type)) + type = attrType->getBaseType(); + else if (auto rateType = as(type)) + type = rateType->getValueType(); + else + return type; + } } // Remove hlsl's 'unorm' and 'snorm' modifiers diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6a3a26bd5..cf58e6cd4 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4848,19 +4848,20 @@ namespace Slang { IRType* type = nullptr; auto basePtrType = as(basePtr->getDataType()); - if (auto arrayType = as(basePtrType->getValueType())) + auto valueType = unwrapAttributedType(basePtrType->getValueType()); + if (auto arrayType = as(valueType)) { type = arrayType->getElementType(); } - else if (auto vectorType = as(basePtrType->getValueType())) + else if (auto vectorType = as(valueType)) { type = vectorType->getElementType(); } - else if (auto matrixType = as(basePtrType->getValueType())) + else if (auto matrixType = as(valueType)) { type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); } - else if (const auto basicType = as(basePtrType->getValueType())) + else if (const auto basicType = as(valueType)) { // HLSL support things like float.x, in which case we just return the base pointer. return basePtr; @@ -4884,10 +4885,11 @@ namespace Slang for (auto access : accessChain) { auto basePtrType = cast(basePtr->getDataType()); + auto valueType = unwrapAttributedType(basePtrType->getValueType()); IRType* resultType = nullptr; if (auto structKey = as(access)) { - auto structType = as(basePtrType->getValueType()); + auto structType = as(valueType); SLANG_RELEASE_ASSERT(structType); for (auto field : structType->getFields()) { diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index ae1a0e9c2..58ee766cc 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -35,7 +35,7 @@ static const char* kStmtKeywords[] = { "__generic", "__exported", "import", "enum", "break", "continue", "discard", "defer", "cbuffer", "tbuffer", "func", "is", "as", "nullptr", "none", "true", "false", "functype", - "sizeof", "alignof"}; + "sizeof", "alignof", "__target_switch", "__intrinsic_asm"}; static const char* hlslSemanticNames[] = { "register", diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 3a21fa278..cea5d0151 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -1469,9 +1469,10 @@ SlangResult LanguageServer::formatting(const LanguageServerProtocol::DocumentFor } if (m_formatOptions.clangFormatLocation.getLength() == 0) m_formatOptions.clangFormatLocation = findClangFormatTool(); - m_formatOptions.fileName = canonicalPath; + auto options = getFormatOptions(m_workspace, m_formatOptions); + options.fileName = canonicalPath; List exclusionRange = extractFormattingExclusionRanges(doc->getText().getUnownedSlice()); - auto edits = formatSource(doc->getText().getUnownedSlice(), -1, -1, -1, exclusionRange, m_formatOptions); + auto edits = formatSource(doc->getText().getUnownedSlice(), -1, -1, -1, exclusionRange, options); auto textEdits = translateTextEdits(doc, edits); m_connection->sendResult(&textEdits, responseId); return SLANG_OK; @@ -1491,7 +1492,7 @@ SlangResult LanguageServer::rangeFormatting(const LanguageServerProtocol::Docume Index endOffset = doc->getOffset(endLine, endCol); if (m_formatOptions.clangFormatLocation.getLength() == 0) m_formatOptions.clangFormatLocation = findClangFormatTool(); - auto options = m_formatOptions; + auto options = getFormatOptions(m_workspace, m_formatOptions); if (!m_formatOptions.allowLineBreakInRangeFormatting) options.behavior = FormatBehavior::PreserveLineBreak; List exclusionRange = extractFormattingExclusionRanges(doc->getText().getUnownedSlice()); @@ -1520,7 +1521,7 @@ SlangResult LanguageServer::onTypeFormatting(const LanguageServerProtocol::Docum Index line, col; doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); auto cursorOffset = doc->getOffset(line, col); - auto options = m_formatOptions; + auto options = getFormatOptions(m_workspace, m_formatOptions); if (!m_formatOptions.allowLineBreakInOnTypeFormatting) options.behavior = FormatBehavior::PreserveLineBreak; List exclusionRange = extractFormattingExclusionRanges(doc->getText().getUnownedSlice()); @@ -1760,6 +1761,19 @@ void LanguageServer::logMessage(int type, String message) m_connection->sendCall(LanguageServerProtocol::LogMessageParams::methodName, &args); } +FormatOptions LanguageServer::getFormatOptions(Workspace* workspace, FormatOptions inOptions) +{ + FormatOptions result = inOptions; + if (workspace->rootDirectories.getCount()) + { + result.clangFormatLocation = StringUtil::replaceAll( + result.clangFormatLocation.getUnownedSlice(), + toSlice("${workspaceFolder}"), + workspace->rootDirectories.getFirst().getUnownedSlice()); + } + return result; +} + SlangResult LanguageServer::tryGetMacroHoverInfo( WorkspaceVersion* version, DocumentVersion* doc, Index line, Index col, JSONValue responseId) { diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h index 175ccf0f1..b17991714 100644 --- a/source/slang/slang-language-server.h +++ b/source/slang/slang-language-server.h @@ -152,6 +152,7 @@ private: void registerCapability(const char* methodName); void logMessage(int type, String message); + FormatOptions getFormatOptions(Workspace* workspace, FormatOptions inOptions); SlangResult tryGetMacroHoverInfo( WorkspaceVersion* version, DocumentVersion* doc, -- cgit v1.2.3