From 051607368e8d3dd55d2ad2b2200ef656244ec70d Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 17 Feb 2023 13:23:27 -0800 Subject: Fixed crash when lowering IR for no_diff struct member. (#2658) * Fixed crash when lowering IR for no_diff struct member. * Improve `setInsertBeforeOrdinaryInst` and `setInsertAfterOrdinaryInst`. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 12 ++++++++++++ source/slang/slang-check-impl.h | 2 +- source/slang/slang-check-type.cpp | 9 +++++++++ source/slang/slang-ir-autodiff-unzip.h | 6 ++++-- source/slang/slang-lower-to-ir.cpp | 12 ++++++++++++ 5 files changed, 38 insertions(+), 3 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 837dcb8eb..381efa2c7 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1291,6 +1291,16 @@ namespace Slang { checkExtensionExternVarAttribute(varDecl, extensionExternAttr); } + + // If a var decl has no_diff type, move the no_diff modifier from the type to the var. + if (auto modifiedType = as(varDecl->type.type)) + { + if (auto nodiffModifier = modifiedType->findModifier()) + { + varDecl->type.type = getRemovedModifierType(modifiedType, nodiffModifier); + addModifier(varDecl, m_astBuilder->getOrCreate()); + } + } } void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) @@ -1527,6 +1537,8 @@ namespace Slang // Go through all var members. for (auto member : context->parentDecl->getMembersOfType()) { + if (member->hasModifier()) + continue; auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); if (!diffType) continue; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 719706635..ccc739da3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -508,7 +508,7 @@ namespace Slang Type* TranslateTypeNode(Expr* node); TypeExp TranslateTypeNodeForced(TypeExp const& typeExp); TypeExp TranslateTypeNode(TypeExp const& typeExp); - + Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); DeclRefType* getExprDeclRefType(Expr * expr); /// Is `decl` usable as a static member? diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index d402dde03..1b2179144 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -84,6 +84,15 @@ namespace Slang return TranslateTypeNodeForced(typeExp); } + Type* SemanticsVisitor::getRemovedModifierType(ModifiedType* modifiedType, ModifierVal* modifier) + { + if (modifiedType->modifiers.getCount() == 1) + return modifiedType->base; + auto newModifiers = modifiedType->modifiers; + newModifiers.remove(modifier); + return m_astBuilder->getModifiedType(modifiedType->base, newModifiers); + } + Expr* SemanticsVisitor::ExpectATypeRepr(Expr* expr) { if (auto overloadedExpr = as(expr)) diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index c3af52d8a..eb8b09417 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -503,7 +503,8 @@ struct DiffUnzipPass if (as(inst)) { SLANG_RELEASE_ASSERT(as(inst->getParent())); - builder->setInsertBefore(as(inst->getParent())->getFirstOrdinaryInst()); + auto lastParam = as(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); } else { @@ -516,7 +517,8 @@ struct DiffUnzipPass if (as(inst)) { SLANG_RELEASE_ASSERT(as(inst->getParent())); - builder->setInsertBefore(as(inst->getParent())->getFirstOrdinaryInst()); + auto lastParam = as(inst->getParent())->getLastParam(); + builder->setInsertAfter(lastParam); } else { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f78dd39e5..aa2dc4efb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3320,6 +3320,7 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo getSimpleDefaultVal(IRType* type) { + type = (IRType*)unwrapAttributedType(type); if(auto basicType = as(type)) { switch( basicType->getBaseType() ) @@ -3355,8 +3356,18 @@ struct ExprLoweringVisitorBase : ExprVisitor UNREACHABLE_RETURN(LoweredValInfo()); } + Type* getOriginalTypeFromModifiedType(Type* type) + { + auto innerType = type; + while (auto modifiedType = as(innerType)) + innerType = modifiedType->base; + return innerType; + } + LoweredValInfo getDefaultVal(Type* type) { + type = getOriginalTypeFromModifiedType(type); + auto irType = lowerType(context, type); if (auto basicType = as(type)) { @@ -7909,6 +7920,7 @@ struct DeclLoweringVisitor : DeclVisitor bool isClassType(IRType* type) { + type = (IRType*)unwrapAttributedType(type); if (auto specialize = as(type)) { return findSpecializeReturnVal(specialize)->getOp() == kIROp_ClassType; -- cgit v1.2.3