diff options
Diffstat (limited to 'source/slang')
29 files changed, 217 insertions, 26 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 43640eb41..87697076e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -679,6 +679,12 @@ struct Ref {}; __generic<T> +__magic_type(ConstRefType) +__intrinsic_type($(kIROp_ConstRefType)) +struct ConstRef +{}; + +__generic<T> __magic_type(OptionalType) __intrinsic_type($(kIROp_OptionalType)) struct Optional @@ -2237,6 +2243,9 @@ attribute_syntax [mutating] : MutatingAttribute; __attributeTarget(SetterDecl) attribute_syntax [nonmutating] : NonmutatingAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [constref] : ConstRefAttribute; + /// Indicates that a function computes its result as a function of its arguments without loading/storing any memory or other state. /// /// This is equivalent to the LLVM `readnone` function attribute. diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index b2d1e5c09..6a9aad257 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -311,6 +311,11 @@ RefType* ASTBuilder::getRefType(Type* valueType) return dynamicCast<RefType>(getPtrType(valueType, "RefType")); } +ConstRefType* ASTBuilder::getConstRefType(Type* valueType) +{ + return dynamicCast<ConstRefType>(getPtrType(valueType, "ConstRefType")); +} + OptionalType* ASTBuilder::getOptionalType(Type* valueType) { auto rsType = getSpecializedBuiltinType(valueType, "OptionalType"); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index f75ab960f..a5e1cd40c 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -451,6 +451,9 @@ public: // Construct the type `Ref<valueType>` RefType* getRefType(Type* valueType); + // Construct the type `ConstRef<valueType>` + ConstRefType* getConstRefType(Type* valueType); + // Construct the type `Optional<valueType>` OptionalType* getOptionalType(Type* valueType); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index af5823db4..68824b931 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -161,6 +161,11 @@ class RefModifier : public Modifier SLANG_AST_CLASS(RefModifier) }; +// `__ref` modifier for by-reference parameter passing +class ConstRefModifier : public Modifier +{ + SLANG_AST_CLASS(ConstRefModifier) +}; // This is a special sentinel modifier that gets added // to the list when we have multiple variable declarations @@ -919,6 +924,14 @@ class NonmutatingAttribute : public Attribute SLANG_AST_CLASS(NonmutatingAttribute) }; +// A `[constref]` attribute, which indicates that the `this` parameter of +// a member function should be passed by reference. +// +class ConstRefAttribute : public Attribute +{ + SLANG_AST_CLASS(ConstRefAttribute) +}; + // A `[__readNone]` attribute, which indicates that a function // computes its results strictly based on argument values, without diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 0b4e9cab2..43b73892c 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1550,6 +1550,7 @@ namespace Slang kParameterDirection_Out, ///< Copy out kParameterDirection_InOut, ///< Copy in, copy out kParameterDirection_Ref, ///< By-reference + kParameterDirection_ConstRef, ///< By-const-reference }; /// The kind of a builtin interface requirement that can be automatically synthesized. diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index f80de86fd..a29ff9bb3 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -372,6 +372,10 @@ ParameterDirection FuncType::getParamDirection(Index index) { return kParameterDirection_Ref; } + else if (as<ConstRefType>(paramType)) + { + return kParameterDirection_ConstRef; + } else if (as<InOutType>(paramType)) { return kParameterDirection_InOut; diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 50d523cc5..3c50b1899 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -552,12 +552,23 @@ class InOutType : public OutTypeBase SLANG_AST_CLASS(InOutType) }; +class RefTypeBase : public ParamDirectionType +{ + SLANG_AST_CLASS(RefTypeBase) +}; + // The type for an `ref` parameter, e.g., `ref T` -class RefType : public ParamDirectionType +class RefType : public RefTypeBase { SLANG_AST_CLASS(RefType) }; +// The type for an `constref` parameter, e.g., `constref T` +class ConstRefType : public RefTypeBase +{ + SLANG_AST_CLASS(ConstRefType) +}; + class OptionalType : public BuiltinType { SLANG_AST_CLASS(OptionalType) diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index c4efba658..c2cfabcfe 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -991,11 +991,12 @@ namespace Slang return true; } - if (auto refType = as<RefType>(toType)) + if (auto refType = as<RefTypeBase>(toType)) { - if (!refType->getValueType()->equals(fromType)) + ConversionCost cost; + if (!canCoerce(refType->getValueType(), fromType, fromExpr, &cost)) return false; - if (!fromExpr->type.isLeftValue) + if (as<RefType>(toType) && !fromExpr->type.isLeftValue) return false; ConversionCost subCost = kConversionCost_GetRef; @@ -1016,7 +1017,7 @@ namespace Slang // Allow implicit dereferencing a reference type. - if (auto fromRefType = as<RefType>(fromType)) + if (auto fromRefType = as<RefTypeBase>(fromType)) { auto fromValueType = fromRefType->getValueType(); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 56b0a991b..351d5a9cc 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1843,6 +1843,14 @@ namespace Slang return false; } + if (satisfyingMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>() + && !requiredMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>()) + { + // A `[constref]` method can't satisfy a non-`[constref]` requirement, + // but vice-versa is okay. + return false; + } + if(satisfyingMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>() != requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>()) { @@ -2710,7 +2718,14 @@ namespace Slang auto synMutatingAttr = m_astBuilder->create<MutatingAttribute>(); addModifier(synFuncDecl, synMutatingAttr); } - + if (requiredMemberDeclRef.getDecl()->hasModifier<ConstRefAttribute>()) + { + // If the interface requirement is `[constref]` then our + // synthesized method should be too. + // + auto synConstRefAttr = m_astBuilder->create<ConstRefAttribute>(); + addModifier(synFuncDecl, synConstRefAttr); + } if (requiredMemberDeclRef.getDecl()->hasModifier<NoDiffThisAttribute>()) { auto noDiffThisAttr = m_astBuilder->create<NoDiffThisAttribute>(); @@ -5161,6 +5176,11 @@ namespace Slang // if(fstParam.getDecl()->hasModifier<RefModifier>() != sndParam.getDecl()->hasModifier<RefModifier>()) return false; + + // If one parameter is `constref` and the other isn't, then they don't match. + // + if (fstParam.getDecl()->hasModifier<ConstRefModifier>() != sndParam.getDecl()->hasModifier<ConstRefModifier>()) + return false; } // Note(tfoley): return type doesn't enter into it, because we can't take @@ -5625,20 +5645,31 @@ namespace Slang // Remove all existing direction modifiers, and replace them with a single Ref modifier. List<Modifier*> newModifiers; bool hasRefModifier = false; + bool isMutable = false; for (auto modifier : paramDecl->modifiers) { - if (as<InModifier>(modifier) || as<InOutModifier>(modifier) || as<OutModifier>(modifier)) + if (as<InModifier>(modifier)) + { + continue; + } + else if (as<InOutModifier>(modifier) || as<OutModifier>(modifier)) { + isMutable = true; continue; } - if (as<RefModifier>(modifier)) + if (as<RefModifier>(modifier) || as<ConstRefModifier>(modifier)) { hasRefModifier = true; } newModifiers.add(modifier); } if (!hasRefModifier) - newModifiers.add(this->getASTBuilder()->create<RefModifier>()); + { + if (isMutable) + newModifiers.add(this->getASTBuilder()->create<RefModifier>()); + else + newModifiers.add(this->getASTBuilder()->create<ConstRefModifier>()); + } paramDecl->modifiers.first = newModifiers.getFirst(); for (Index i = 0; i < newModifiers.getCount(); i++) { @@ -5774,6 +5805,9 @@ namespace Slang case ParameterDirection::kParameterDirection_Ref: addModifier(param, m_astBuilder->create<RefModifier>()); break; + case ParameterDirection::kParameterDirection_ConstRef: + addModifier(param, m_astBuilder->create<ConstRefModifier>()); + break; default: break; } @@ -5879,7 +5913,9 @@ namespace Slang // specialization. for (auto paramDecl : decl->getParameters()) { - if (paramDecl->type.type && !isTypeDifferentiable(paramDecl->type.type)) + if (!paramDecl->type.type) + continue; + if (!isTypeDifferentiable(paramDecl->type.type)) { if (!paramDecl->hasModifier<NoDiffModifier>()) { @@ -5888,6 +5924,24 @@ namespace Slang addModifier(paramDecl, noDiffModifier); } } + if (!paramDecl->hasModifier<NoDiffModifier>()) + { + if (auto modifier = paramDecl->findModifier<ConstRefModifier>()) + { + getSink()->diagnose(modifier, Diagnostics::cannotUseConstRefOnDifferentiableParameter); + } + } + } + if (!isEffectivelyStatic(decl)) + { + auto constrefAttr = decl->findModifier<ConstRefAttribute>(); + if (constrefAttr) + { + if (isTypeDifferentiable(calcThisType(getParentDecl(decl)))) + { + getSink()->diagnose(constrefAttr, Diagnostics::cannotUseConstRefOnDifferentiableMemberMethod); + } + } } } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5ea897448..abdd89b01 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -214,11 +214,11 @@ namespace Slang { auto exprType = expr->type.type; - if (auto refType = as<RefType>(exprType)) + if (auto refType = as<RefTypeBase>(exprType)) { auto openRef = m_astBuilder->create<OpenRefExpr>(); openRef->innerExpr = expr; - openRef->type.isLeftValue = true; + openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr); openRef->type.type = refType->getValueType(); return openRef; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 183a9bc28..5c169a4ba 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -560,6 +560,8 @@ DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument ' DIAGNOSTIC(38031, Error, invalidUseOfNoDiff, "'no_diff' can only be used to decorate a call or a subscript operation") DIAGNOSTIC(38032, Error, useOfNoDiffOnDifferentiableFunc, "use 'no_diff' on a call to a differentiable function has no meaning.") DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no_diff' in a non-differentiable function.") +DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableParameter, "cannot use '__constref' on a differentiable parameter.") +DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableMemberMethod, "cannot use '[constref]' on a differentiable member method of a differentiable type.") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index d26893987..2559269d4 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3235,6 +3235,10 @@ void CLikeSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) m_writer->emit("inout "); type = refType->getValueType(); } + else if (auto constRefType = as<IRConstRefType>(type)) + { + type = constRefType->getValueType(); + } emitType(type, name); } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 81bc2203a..ebe5965f4 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -403,8 +403,9 @@ void CPPSourceEmitter::useType(IRType* type) break; } case kIROp_RefType: + case kIROp_ConstRefType: { - type = static_cast<IRRefType*>(type)->getValueType(); + type = static_cast<IRPtrTypeBase*>(type)->getValueType(); break; } default: break; @@ -1039,6 +1040,7 @@ void CPPSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) } break; case kIROp_RefType: + case kIROp_ConstRefType: { auto ptrType = cast<IRPtrTypeBase>(type); PtrDeclaratorInfo refDeclarator(declarator); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 6a223ae4f..7cfd9ffad 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1277,6 +1277,7 @@ struct SPIRVEmitContext } case kIROp_PtrType: case kIROp_RefType: + case kIROp_ConstRefType: case kIROp_OutType: case kIROp_InOutType: { diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 4d44aac1f..6fb21e1ef 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -97,6 +97,12 @@ struct AddressInstEliminationContext auto addr = use->get(); auto call = as<IRCall>(use->getUser()); + // Don't change the use if addr is a non mutable address. + if (auto refType = as<IRConstRefType>(getRootAddr(addr)->getDataType())) + { + return; + } + IRBuilder builder(module); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType()); @@ -123,6 +129,8 @@ struct AddressInstEliminationContext { for (auto inst : block->getChildren()) { + if (as<IRConstRefType>(getRootAddr(inst)->getDataType())) + continue; if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType())) { auto valType = unwrapAttributedType(ptrType->getValueType()); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 61baa7dd7..f8c4cee66 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1478,9 +1478,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( continue; IRBlock* defBlock = nullptr; - if (const auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) + if (auto varInst = as<IRVar>(instToStore)) { - auto varInst = as<IRVar>(instToStore); auto storeUse = findEarliestUniqueWriteUse(varInst); defBlock = getBlock(storeUse->getUser()); @@ -2126,6 +2125,10 @@ bool DefaultCheckpointPolicy::canRecompute(UseOrPseudoUse use) // We can't recompute a 'load' from a mutable function parameter. if (as<IRParam>(ptr) || as<IRVar>(ptr)) { + // An exception is a load of a constref parameter, which should + // remain constant throughout the function. + if (as<IRConstRefType>(getRootAddr(ptr)->getDataType())) + return true; if (isInstInPrimalOrTransposedParameterBlocks(ptr)) return false; } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 335b6572e..ed70862d1 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -415,9 +415,9 @@ namespace Slang // Fetch primal values to use as arguments in primal func call. IRInst* primalArg = param; - if (!as<IROutType>(primalParamType)) + if (!as<IROutType>(primalParamType) && !as<IRConstRefType>(primalParamType)) { - // As long as the primal parameter is not an out type, + // As long as the primal parameter is not an out or constref type, // we need to fetch the primal value from the parameter. if (as<IRPtrTypeBase>(propagateParamType)) { @@ -428,7 +428,7 @@ namespace Slang primalArg = builder.emitDifferentialPairGetPrimal(primalArg); } } - if (auto primalParamPtrType = as<IRPtrTypeBase>(primalParamType)) + if (auto primalParamPtrType = isMutablePointerType(primalParamType)) { // If primal parameter is mutable, we need to pass in a temp var. auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index a087e59d7..0406f224e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -108,6 +108,7 @@ INST(Nop, nop, 0, 0) /* PtrTypeBase */ INST(PtrType, Ptr, 1, HOISTABLE) INST(RefType, Ref, 1, HOISTABLE) + INST(ConstRefType, ConstRef, 1, HOISTABLE) // A `PsuedoPtr<T>` logically represents a pointer to a value of type // `T` on a platform that cannot support pointers. The expectation // is that the "pointer" will be legalized away by storing a value diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 815966218..038681a58 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3246,6 +3246,7 @@ public: IROutType* getOutType(IRType* valueType); IRInOutType* getInOutType(IRType* valueType); IRRefType* getRefType(IRType* valueType); + IRConstRefType* getConstRefType(IRType* valueType); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace); diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index c87eca587..79012e7ba 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -287,6 +287,7 @@ case kIROp_##TYPE##Type: \ case kIROp_OutType: case kIROp_InOutType: case kIROp_RefType: + case kIROp_ConstRefType: case kIROp_RawPointerType: case kIROp_PtrType: case kIROp_NativePtrType: diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp index 3fcc02de0..401828daa 100644 --- a/source/slang/slang-ir-marshal-native-call.cpp +++ b/source/slang/slang-ir-marshal-native-call.cpp @@ -18,6 +18,7 @@ namespace Slang return builder.getNativePtrType((IRType*)as<IRComPtrType>(type)->getOperand(0)); case kIROp_InOutType: case kIROp_RefType: + case kIROp_ConstRefType: case kIROp_OutType: return builder.getPtrType(getNativeType(builder, (IRType*)type->getOperand(0))); default: @@ -76,6 +77,7 @@ namespace Slang { case kIROp_InOutType: case kIROp_RefType: + case kIROp_ConstRefType: case kIROp_OutType: return marshalRefManagedValueToNativeValue( builder, originalArg, args); @@ -135,6 +137,7 @@ namespace Slang { case kIROp_InOutType: case kIROp_RefType: + case kIROp_ConstRefType: SLANG_UNREACHABLE("out and ref types should be handled before reaching here."); break; case kIROp_StringType: diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 62da8cffd..b4a41f8a5 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -418,6 +418,7 @@ bool isPtrLikeOrHandleType(IRInst* type) case kIROp_InOutType: case kIROp_PtrType: case kIROp_RefType: + case kIROp_ConstRefType: return true; } return false; @@ -970,6 +971,17 @@ bool isOne(IRInst* inst) } } +IRPtrTypeBase* isMutablePointerType(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_ConstRefType: + return nullptr; + default: + return as<IRPtrTypeBase>(inst); + } +} + void initializeScratchData(IRInst* inst) { List<IRInst*> workList; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 8af4f7536..82ce1344c 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -221,6 +221,8 @@ bool isZero(IRInst* inst); bool isOne(IRInst* inst); +IRPtrTypeBase* isMutablePointerType(IRInst* inst); + void initializeScratchData(IRInst* inst); void resetScratchDataBit(IRInst* inst, int bitIndex); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 408b87e49..f8d2b3117 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2755,6 +2755,11 @@ namespace Slang return (IRRefType*) getPtrType(kIROp_RefType, valueType); } + IRConstRefType* IRBuilder::getConstRefType(IRType* valueType) + { + return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType); + } + IRSPIRVLiteralType* IRBuilder::getSPIRVLiteralType(IRType* type) { IRInst* operands[] = { type }; @@ -3588,6 +3593,7 @@ namespace Slang case kIROp_OutType: case kIROp_RawPointerType: case kIROp_RefType: + case kIROp_ConstRefType: case kIROp_ComPtrType: case kIROp_NativePtrType: case kIROp_NativeStringType: @@ -3697,6 +3703,7 @@ namespace Slang case kIROp_OutType: case kIROp_RawPointerType: case kIROp_RefType: + case kIROp_ConstRefType: return 3; case kIROp_VoidType: return 4; diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index a24e2e5c7..cb744d4da 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1637,6 +1637,7 @@ struct IRPtrTypeBase : IRType SIMPLE_IR_TYPE(PtrType, PtrTypeBase) SIMPLE_IR_TYPE(RefType, PtrTypeBase) +SIMPLE_IR_TYPE(ConstRefType, PtrTypeBase) SIMPLE_IR_PARENT_TYPE(OutTypeBase, PtrTypeBase) SIMPLE_IR_TYPE(OutType, OutTypeBase) SIMPLE_IR_TYPE(InOutType, OutTypeBase) diff --git a/source/slang/slang-language-server-inlay-hints.cpp b/source/slang/slang-language-server-inlay-hints.cpp index 801e28445..35b603b20 100644 --- a/source/slang/slang-language-server-inlay-hints.cpp +++ b/source/slang/slang-language-server-inlay-hints.cpp @@ -63,6 +63,7 @@ List<LanguageServerProtocol::InlayHint> getInlayHints( if (param->hasModifier<OutModifier>()) lblSb << "out "; else if (param->hasModifier<InOutModifier>()) lblSb << "inout "; else if (param->hasModifier<RefModifier>()) lblSb << "ref "; + else if (param->hasModifier<ConstRefModifier>()) lblSb << "constref "; lblSb << name->text; lblSb << ":"; hint.label = lblSb.produceString(); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ce577de6e..09992fb14 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2414,6 +2414,7 @@ void addArg( case kParameterDirection_Out: case kParameterDirection_InOut: + case kParameterDirection_ConstRef: { // According to our "calling convention" we need to // pass a pointer into the callee. @@ -2434,6 +2435,12 @@ void addArg( // If the value is not one that could yield a simple l-value // then we need to convert it into a temporary // + if (auto refType = as<IRConstRefType>(paramType)) + { + paramType = refType->getValueType(); + argVal = LoweredValInfo::simple(context->irBuilder->emitLoad(getSimpleVal(context, argPtr))); + } + LoweredValInfo tempVar = createVar(context, paramType); // If the parameter is `in out` or `inout`, then we need @@ -2441,7 +2448,8 @@ void addArg( // in the argument, which we accomplish by assigning // from the l-value to our temp. // - if (paramDirection == kParameterDirection_InOut) + if (paramDirection == kParameterDirection_InOut || + paramDirection == kParameterDirection_ConstRef) { assign(context, tempVar, argVal); } @@ -2455,11 +2463,14 @@ void addArg( // Finally, after the call we will need // to copy in the other direction: from our // temp back to the original l-value. - OutArgumentFixup fixup; - fixup.src = tempVar; - fixup.dst = argVal; + if (paramDirection != kParameterDirection_ConstRef) + { + OutArgumentFixup fixup; + fixup.src = tempVar; + fixup.dst = argVal; - (*ioFixups).add(fixup); + (*ioFixups).add(fixup); + } } } break; @@ -2492,6 +2503,7 @@ void addCallArgsForParam( switch(paramDirection) { case kParameterDirection_Ref: + case kParameterDirection_ConstRef: case kParameterDirection_Out: case kParameterDirection_InOut: { @@ -2526,6 +2538,10 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl) return kParameterDirection_Ref; } + if (paramDecl->hasModifier<ConstRefModifier>()) + { + return kParameterDirection_ConstRef; + } if( paramDecl->hasModifier<InOutModifier>() ) { // The AST specified `inout`: @@ -2563,7 +2579,10 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de if (parentParent->findModifier<NonCopyableTypeAttribute>()) { - return kParameterDirection_Ref; + if (parentDecl->hasModifier<MutatingAttribute>()) + return kParameterDirection_Ref; + else + return kParameterDirection_ConstRef; } // Applications can opt in to a mutable `this` parameter, @@ -2574,6 +2593,10 @@ ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection de { return kParameterDirection_InOut; } + else if (parentDecl->hasModifier<ConstRefAttribute>()) + { + return kParameterDirection_ConstRef; + } // A `set` accessor on a property or subscript declaration // defaults to a mutable `this` parameter, but the programmer @@ -2988,7 +3011,9 @@ void _lowerFuncDeclBaseTypeInfo( case kParameterDirection_Ref: irParamType = builder->getRefType(irParamType); break; - + case kParameterDirection_ConstRef: + irParamType = builder->getConstRefType(irParamType); + break; default: SLANG_UNEXPECTED("unknown parameter direction"); break; @@ -4030,7 +4055,18 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> { auto loweredBase = lowerLValueExpr(context, expr->base); - SLANG_ASSERT(loweredBase.flavor == LoweredValInfo::Flavor::Ptr); + if (loweredBase.flavor != LoweredValInfo::Flavor::Ptr) + { + SLANG_ASSERT(as<ConstRefType>(expr->type)); + // If the base isn't a pointer, then we are trying to form + // a const ref to a temporary value. + // To do so we must copy it into a variable. + auto baseVal = getSimpleVal(context, loweredBase); + auto tempVar = context->irBuilder->emitVar(baseVal->getFullType()); + context->irBuilder->emitStore(tempVar, baseVal); + loweredBase.val = tempVar; + } + loweredBase.flavor = LoweredValInfo::Flavor::Simple; return loweredBase; } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 0191b1e0c..3f8225084 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7263,6 +7263,7 @@ namespace Slang _makeParseModifier("out", OutModifier::kReflectClassInfo), _makeParseModifier("inout", InOutModifier::kReflectClassInfo), _makeParseModifier("__ref", RefModifier::kReflectClassInfo), + _makeParseModifier("__constref", ConstRefModifier::kReflectClassInfo), _makeParseModifier("const", ConstModifier::kReflectClassInfo), _makeParseModifier("instance", InstanceModifier::kReflectClassInfo), _makeParseModifier("__builtin", BuiltinModifier::kReflectClassInfo), diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 1f96df2f0..8ed50510f 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -603,6 +603,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { paramType = astBuilder->getRefType(paramType); } + else if (paramDecl->findModifier<ConstRefModifier>()) + { + paramType = astBuilder->getConstRefType(paramType); + } else if( paramDecl->findModifier<OutModifier>() ) { if(paramDecl->findModifier<InOutModifier>() || paramDecl->findModifier<InModifier>()) |
