diff options
| -rw-r--r-- | source/slang/core.meta.slang | 93 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 52 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 59 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-constexpr.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 62 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 29 |
20 files changed, 349 insertions, 94 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index b7f50cd1b..267f7b2d4 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1,4 +1,4 @@ -//public module core; +public module core; // Slang `core` library @@ -2367,49 +2367,6 @@ __generic<T> __extension vector<T, 4> ${{{{ -// The above extensions are generic in the *type* of the vector, -// but explicit in the *size*. We will now declare an extension -// for each builtin type that is generic in the size. -// -for (int tt = 0; tt < kBaseTypeCount; ++tt) -{ - if(kBaseTypes[tt].tag == BaseType::Void) continue; - - sb << "__generic<let N : int> __extension vector<" - << kBaseTypes[tt].name << ",N>\n{\n"; - - for (int ff = 0; ff < kBaseTypeCount; ++ff) - { - if(kBaseTypes[ff].tag == BaseType::Void) continue; - - - if( tt != ff ) - { - auto cost = getBaseTypeConversionCost( - kBaseTypes[tt], - kBaseTypes[ff]); - auto op = getBaseTypeConversionOp( - kBaseTypes[tt], - kBaseTypes[ff]); - - // Implicit conversion from a vector of the same - // size, but different element type. - sb << " __implicit_conversion(" << cost << ")\n"; - sb << " __intrinsic_op(" << int(op) << ")\n"; - sb << " __init(vector<" << kBaseTypes[ff].name << ",N> value);\n"; - - // Constructor to make a vector from a scalar of another type. - if (cost != kConversionCost_Impossible) - { - cost += kConversionCost_ScalarToVector; - sb << " __implicit_conversion(" << cost << ")\n"; - sb << " [__unsafeForceInlineEarly]\n"; - sb << " __init(" << kBaseTypes[ff].name << " value) { this = vector<" << kBaseTypes[tt].name << ",N>( " << kBaseTypes[tt].name << "(value)); }\n"; - } - } - } - sb << "}\n"; -} for( int R = 1; R <= 4; ++R ) for( int C = 1; C <= 4; ++C ) @@ -2464,38 +2421,36 @@ for( int C = 1; C <= 4; ++C ) sb << "}\n"; } -for (int tt = 0; tt < kBaseTypeCount; ++tt) -{ - if(kBaseTypes[tt].tag == BaseType::Void) continue; - auto toType = kBaseTypes[tt].name; }}}} -__generic<let R : int, let C : int, let L : int> extension matrix<$(toType),R,C,L> +//@hidden: +__intrinsic_op($(kIROp_BuiltinCast)) +internal T __builtin_cast<T, U>(U u); + +// If T is implicitly convertible to U, then vector<T,N> is implicitly convertible to vector<U,N>. +__generic<ToType, let N : int> extension vector<ToType,N> { -${{{{ - for (int ff = 0; ff < kBaseTypeCount; ++ff) - { - if(kBaseTypes[ff].tag == BaseType::Void) continue; - if( tt == ff ) continue; + __implicit_conversion(constraint) + __intrinsic_op(BuiltinCast) + __init<FromType>(vector<FromType,N> value) where ToType(FromType) implicit; - auto cost = getBaseTypeConversionCost( - kBaseTypes[tt], - kBaseTypes[ff]); - auto fromType = kBaseTypes[ff].name; - auto op = getBaseTypeConversionOp( - kBaseTypes[tt], - kBaseTypes[ff]); -}}}} - __implicit_conversion($(cost)) - __intrinsic_op($(op)) - __init(matrix<$(fromType),R,C,L> value); -${{{{ + __implicit_conversion(constraint+) + [__unsafeForceInlineEarly] + [__readNone] + [TreatAsDifferentiable] + __init<FromType>(FromType value) where ToType(FromType) implicit + { + this = __builtin_cast<vector<ToType,N>>(vector<FromType,N>(value)); } -}}}} } -${{{{ + +// If T is implicitly convertible to U, then matrix<T,R,C,L> is implicitly convertible to matrix<U,R,C,L>. +__generic<ToType, let R : int, let C : int, let L : int> extension matrix<ToType,R,C,L> +{ + __implicit_conversion(constraint) + __intrinsic_op(BuiltinCast) + __init<FromType>(matrix<FromType,R,C,L> value) where ToType(FromType) implicit; } -}}}} //@ hidden: __generic<T, U> diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 6ffaee7db..b3afa5310 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -948,6 +948,14 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( return witness; } +TypeCoercionWitness* ASTBuilder::getTypeCoercionWitness( + Type* subType, + Type* superType, + DeclRef<Decl> declRef) +{ + return getOrCreate<TypeCoercionWitness>(subType, superType, declRef.declRefBase); +} + DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl) { return builder->getMemberDeclRef(parent, decl); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index cae380e40..67dfaaf52 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -636,6 +636,11 @@ public: SubtypeWitness* subIsLWitness, SubtypeWitness* subIsRWitness); + TypeCoercionWitness* getTypeCoercionWitness( + Type* fromType, + Type* toType, + DeclRef<Decl> declRef); + /// Helpers to get type info from the SharedASTBuilder const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) { diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index ff8e5684a..ff55340ac 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -612,6 +612,15 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl const TypeExp& _getSupOverride() const { return sup; } }; +class TypeCoercionConstraintDecl : public Decl +{ + SLANG_AST_CLASS(TypeCoercionConstraintDecl) + + SourceLoc whereTokenLoc = SourceLoc(); + TypeExp fromType; + TypeExp toType; +}; + class GenericValueParamDecl : public VarDeclBase { SLANG_AST_CLASS(GenericValueParamDecl) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index cc4901236..e4d5ccd09 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1279,10 +1279,10 @@ class ImplicitConversionModifier : public Modifier SLANG_AST_CLASS(ImplicitConversionModifier) // The conversion cost, used to rank conversions - ConversionCost cost; + ConversionCost cost = kConversionCost_None; // A builtin identifier for identifying conversions that need special treatment. - BuiltinConversionKind builtinConversionKind; + BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown; }; class FormatAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index d24007721..b3baee98f 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -178,6 +178,11 @@ enum : ConversionCost // Additional cost when casting an LValue. kConversionCost_LValueCast = 800, + // The cost of this conversion is defined by the type coercion constraint. + kConversionCost_TypeCoercionConstraint = 1000, + kConversionCost_TypeCoercionConstraintPlusScalarToVector = + kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector, + // Conversion is impossible kConversionCost_Impossible = 0xFFFFFFFF, }; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 9bcfd21bc..7613dbe80 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -845,6 +845,58 @@ void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) out << ")"; } +void TypeCoercionWitness::_toTextOverride(StringBuilder& out) +{ + out << "TypeCoercionWitness("; + if (getFromType()) + out << getFromType(); + if (getToType()) + out << getToType(); + out << ")"; +} + +Val* TypeCoercionWitness::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) +{ + int diff = 0; + + auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff); + auto substFrom = as<Type>(getFromType()->substituteImpl(astBuilder, subst, &diff)); + auto substTo = as<Type>(getToType()->substituteImpl(astBuilder, subst, &diff)); + + if (!diff) + return this; + + (*ioDiff)++; + + TypeCoercionWitness* substValue = + astBuilder->getTypeCoercionWitness(substFrom, substTo, substDeclRef); + return substValue; +} + +Val* TypeCoercionWitness::_resolveImplOverride() +{ + Val* resolvedDeclRef = nullptr; + if (getDeclRef()) + resolvedDeclRef = getDeclRef().declRefBase->resolve(); + if (auto resolvedVal = as<Witness>(resolvedDeclRef)) + return resolvedVal; + + auto newFrom = as<Type>(getFromType()->resolve()); + auto newTo = as<Type>(getToType()->resolve()); + + auto newDeclRef = as<DeclRefBase>(resolvedDeclRef); + if (!newDeclRef) + newDeclRef = getDeclRef().declRefBase; + if (newFrom != getFromType() || newTo != getToType() || newDeclRef != getDeclRef()) + { + return getCurrentASTBuilder()->getTypeCoercionWitness(newFrom, newTo, newDeclRef); + } + return this; +} + // UNormModifierVal void UNormModifierVal::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 7b33a8111..3a14be17b 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -621,6 +621,20 @@ class TypeEqualityWitness : public SubtypeWitness Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +class TypeCoercionWitness : public Witness +{ + SLANG_AST_CLASS(TypeCoercionWitness) + + Type* getFromType() { return as<Type>(getOperand(0)); } + Type* getToType() { return as<Type>(getOperand(1)); } + + DeclRef<Decl> getDeclRef() { return as<DeclRefBase>(getOperand(2)); } + + void _toTextOverride(StringBuilder& out); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); +}; + // A witness that one type is a subtype of another // because some in-scope declaration says so class DeclaredSubtypeWitness : public SubtypeWitness diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 872d2616c..642a4bf6a 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -715,13 +715,6 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( // system as being solved now, as a result of the witness we found. } - // Add a flat cost to all unconstrained generic params. - for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>()) - { - if (!constrainedGenericParams.contains(typeParamDecl)) - outBaseCost += kConversionCost_UnconstraintGenericParam; - } - // Make sure we haven't constructed any spurious constraints // that we aren't able to satisfy: for (auto c : system->constraints) @@ -732,6 +725,58 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( } } + // Verify that all type coercion constraints can be satisfied. + for (auto constraintDecl : + genericDeclRef.getDecl()->getMembersOfType<TypeCoercionConstraintDecl>()) + { + DeclRef<TypeCoercionConstraintDecl> constraintDeclRef = + m_astBuilder + ->getGenericAppDeclRef( + genericDeclRef, + args.getArrayView().arrayView, + constraintDecl) + .as<TypeCoercionConstraintDecl>(); + auto fromType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->fromType.Ptr()); + auto toType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->toType.Ptr()); + auto conversionCost = getConversionCost(toType, fromType); + if (constraintDecl->findModifier<ImplicitConversionModifier>()) + { + if (conversionCost > kConversionCost_GeneralConversion) + { + // The type arguments are not implicitly convertible, return failure. + return DeclRef<Decl>(); + } + } + else + { + if (conversionCost == kConversionCost_Impossible) + { + // The type arguments are not convertible, return failure. + return DeclRef<Decl>(); + } + } + if (auto fromDecl = isDeclRefTypeOf<Decl>(constraintDecl->fromType)) + { + constrainedGenericParams.add(fromDecl.getDecl()); + } + if (auto toDecl = isDeclRefTypeOf<Decl>(constraintDecl->toType)) + { + constrainedGenericParams.add(toDecl.getDecl()); + } + // If we are to expand the support of type coercion constraint beyond simple builtin core + // module functions, then the witness should be a reference to the conversion function. For + // now, this isn't required, and it is not easy to get it from the coercion logic, so we + // leave it empty. + args.add(m_astBuilder->getTypeCoercionWitness(fromType, toType, DeclRef<Decl>())); + } + + // Add a flat cost to all unconstrained generic params. + for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>()) + { + if (!constrainedGenericParams.contains(typeParamDecl)) + outBaseCost += kConversionCost_UnconstraintGenericParam; + } + return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView); } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 6dda9c1ea..a9785a585 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1045,11 +1045,28 @@ int getTypeBitSize(Type* t) } ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg( - Decl* decl, + DeclRef<Decl> decl, Type* toType, Expr* arg) { - ConversionCost candidateCost = getImplicitConversionCost(decl); + ConversionCost candidateCost = getImplicitConversionCost(decl.getDecl()); + + if (candidateCost == kConversionCost_TypeCoercionConstraint || + candidateCost == kConversionCost_TypeCoercionConstraintPlusScalarToVector) + { + if (auto genApp = as<GenericAppDeclRef>(decl.declRefBase)) + { + for (auto genArg : genApp->getArgs()) + { + if (auto wit = as<TypeCoercionWitness>(genArg)) + { + candidateCost -= kConversionCost_TypeCoercionConstraint; + candidateCost += getConversionCost(wit->getToType(), wit->getFromType()); + break; + } + } + } + } // Fix up the cost if the operand is a const lit. if (isScalarIntegerType(toType)) @@ -1577,10 +1594,8 @@ bool SemanticsVisitor::_coerce( ImplicitCastMethod method; for (auto candidate : overloadContext.bestCandidates) { - ConversionCost candidateCost = getImplicitConversionCostWithKnownArg( - candidate.item.declRef.getDecl(), - toType, - fromExpr); + ConversionCost candidateCost = + getImplicitConversionCostWithKnownArg(candidate.item.declRef, toType, fromExpr); if (candidateCost < bestCost) { method.conversionFuncOverloadCandidate = candidate; @@ -1632,7 +1647,7 @@ bool SemanticsVisitor::_coerce( // cost associated with the initializer we are invoking. // ConversionCost cost = getImplicitConversionCostWithKnownArg( - overloadContext.bestCandidate->item.declRef.getDecl(), + overloadContext.bestCandidate->item.declRef, toType, fromExpr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1ef5b1cec..5b5e05b73 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -147,6 +147,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase, void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl); + void visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl); + void validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type); void visitGenericDecl(GenericDecl* genericDecl); @@ -2911,6 +2913,16 @@ void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType( } } +void SemanticsDeclHeaderVisitor::visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl) +{ + CheckConstraintSubType(decl->toType); + + if (!decl->fromType.type) + decl->fromType = TranslateTypeNodeForced(decl->fromType); + if (!decl->toType.type) + decl->toType = TranslateTypeNodeForced(decl->toType); +} + void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) { // TODO: are there any other validations we can do at this point? diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 59290f8ad..6438a91e3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1511,7 +1511,10 @@ public: // perform implicit type conversion. ConversionCost getImplicitConversionCost(Decl* decl); - ConversionCost getImplicitConversionCostWithKnownArg(Decl* decl, Type* toType, Expr* arg); + ConversionCost getImplicitConversionCostWithKnownArg( + DeclRef<Decl> decl, + Type* toType, + Expr* arg); BuiltinConversionKind getImplicitConversionBuiltinKind(Decl* decl); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index b944d2bf4..b75f95f9a 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1675,6 +1675,15 @@ int SemanticsVisitor::CompareOverloadCandidates(OverloadCandidate* left, Overloa if (itemDiff) return itemDiff; + // If one candidate is an implicit conversion, and other candidate is not, + // then we should prefer the implicit conversion. + int leftIsImplicitConversion = + left->item.declRef.getDecl()->findModifier<ImplicitConversionModifier>() ? 1 : 0; + int rightIsImplicitConversion = + right->item.declRef.getDecl()->findModifier<ImplicitConversionModifier>() ? 1 : 0; + if (leftIsImplicitConversion != rightIsImplicitConversion) + return rightIsImplicitConversion - leftIsImplicitConversion; + auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); if (specificityDiff) return specificityDiff; diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index ff6d64319..620c65d4e 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -116,6 +116,7 @@ bool opCanBeConstExpr(IROp op) case kIROp_PtrCast: case kIROp_Reinterpret: case kIROp_BitCast: + case kIROp_BuiltinCast: case kIROp_MakeTuple: case kIROp_MakeDifferentialPair: case kIROp_MakeExistential: @@ -178,7 +179,13 @@ bool opCanBeConstExprByBackwardPass(IRInst* value) { if (value->getOp() == kIROp_Param) return isLoopPhi(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(value)); - return opCanBeConstExpr(value->getOp()); + if (opCanBeConstExpr(value->getOp())) + return true; + if (auto callInst = as<IRCall>(value)) + { + return !callInst->mightHaveSideEffects(); + } + return false; } void markConstExpr(PropagateConstExprContext* context, IRInst* value) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 55880eab5..5a1966d00 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1202,6 +1202,7 @@ INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOIST INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) +INST(BuiltinCast, BuiltinCast, 1, 0) INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) INST(Unmodified, unmodified, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index dbefa68c7..d64820aa6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4025,7 +4025,7 @@ public: /// the inst. IRInst* emitDefaultConstructRaw(IRType* type); - IRInst* emitCast(IRType* type, IRInst* value); + IRInst* emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinCast = true); IRInst* emitVectorReshape(IRType* type, IRInst* value); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index fc399954b..e29fdf975 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -98,6 +98,7 @@ struct PeepholeContext : InstPassBase else if (remainingKeys.getCount() > 0) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto newValue = builder.emitElementExtract(updateInst->getElementValue(), remainingKeys); @@ -112,6 +113,7 @@ struct PeepholeContext : InstPassBase // accessChain!=accessChain2, then we can replace the inst with extract(x, // accessChain2). IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView()); @@ -140,6 +142,8 @@ struct PeepholeContext : InstPassBase if (vectorType->getElementType() != replacement->getFullType()) return false; IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); replacement = builder.emitMakeVectorFromScalar(inst->getFullType(), replacement); @@ -175,6 +179,7 @@ struct PeepholeContext : InstPassBase else if (inst->getOperand(0) == inst->getOperand(1)) { IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); return tryReplace(builder.emitDefaultConstruct(inst->getDataType())); } @@ -280,6 +285,8 @@ struct PeepholeContext : InstPassBase break; IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); IRInst* resultVal = nullptr; if (inst->getOp() == kIROp_AlignOf) @@ -319,6 +326,8 @@ struct PeepholeContext : InstPassBase if (inst->getOperand(0)->getOp() == kIROp_MakeResultError) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + inst->replaceUsesWith(builder.getBoolValue(true)); maybeRemoveOldInst(inst); changed = true; @@ -326,6 +335,8 @@ struct PeepholeContext : InstPassBase else if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + inst->replaceUsesWith(builder.getBoolValue(false)); maybeRemoveOldInst(inst); changed = true; @@ -359,6 +370,8 @@ struct PeepholeContext : InstPassBase if (const auto packType = as<IRTypePack>(pack->getDataType())) { IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); List<IRInst*> args; for (UInt j = 0; j < packType->getOperandCount(); ++j) @@ -443,6 +456,8 @@ struct PeepholeContext : InstPassBase index->getValue() < startIndex + vecSize->getValue()) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newElement = builder.emitElementExtract( element, @@ -517,6 +532,8 @@ struct PeepholeContext : InstPassBase if (args.getCount() == arraySize->getValue()) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeArray = builder.emitMakeArray( arrayType, @@ -573,6 +590,8 @@ struct PeepholeContext : InstPassBase if (isComplete) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeArray = builder.emitMakeArray( arrayType, @@ -618,6 +637,8 @@ struct PeepholeContext : InstPassBase if (isValid) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeStruct = builder.emitMakeStruct( structType, @@ -678,6 +699,8 @@ struct PeepholeContext : InstPassBase // Create a makeStruct inst using args. IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto makeStruct = builder.emitMakeStruct( structType, @@ -694,6 +717,8 @@ struct PeepholeContext : InstPassBase { auto ptr = inst->getOperand(0); IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto neq = builder.emitNeq(ptr, builder.getNullPtrValue(ptr->getDataType())); inst->replaceUsesWith(neq); @@ -708,6 +733,8 @@ struct PeepholeContext : InstPassBase if (isTypeEqual(actualType, (IRType*)isTypeInst->getTypeOperand())) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto trueVal = builder.getBoolValue(true); inst->replaceUsesWith(trueVal); @@ -770,6 +797,7 @@ struct PeepholeContext : InstPassBase if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto trueVal = builder.getBoolValue(true); inst->replaceUsesWith(trueVal); @@ -779,6 +807,8 @@ struct PeepholeContext : InstPassBase else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto falseVal = builder.getBoolValue(false); inst->replaceUsesWith(falseVal); @@ -841,6 +871,7 @@ struct PeepholeContext : InstPassBase case kIROp_DefaultConstruct: { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); // See if we can replace the default construct inst with concrete values. if (auto newCtor = builder.emitDefaultConstruct(inst->getFullType(), false)) @@ -851,6 +882,21 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_BuiltinCast: + { + IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); + // See if we can replace the default construct inst with concrete values. + if (auto newCast = + builder.emitCast(inst->getFullType(), inst->getOperand(0), false)) + { + inst->replaceUsesWith(newCast); + maybeRemoveOldInst(inst); + changed = true; + } + } + break; case kIROp_VectorReshape: { auto fromType = as<IRVectorType>(inst->getOperand(0)->getDataType()); @@ -867,6 +913,7 @@ struct PeepholeContext : InstPassBase break; } IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); UInt index = 0; auto newInst = builder.emitSwizzle(resultType, inst->getOperand(0), 1, &index); @@ -882,6 +929,8 @@ struct PeepholeContext : InstPassBase if (!toCount) break; IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newInst = builder.emitVectorReshape(resultType, inst->getOperand(0)); if (newInst != inst) @@ -911,6 +960,7 @@ struct PeepholeContext : InstPassBase break; List<IRInst*> rows; IRBuilder builder(inst); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); builder.setInsertBefore(inst); auto toRowType = builder.getVectorType( resultType->getElementType(), @@ -1035,6 +1085,8 @@ struct PeepholeContext : InstPassBase break; } IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newInst = builder.emitMakeVectorFromScalar(vectorType, inst->getOperand(0)); @@ -1075,6 +1127,8 @@ struct PeepholeContext : InstPassBase else { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto newMakeVector = builder.emitMakeVector( swizzle->getDataType(), @@ -1100,6 +1154,8 @@ struct PeepholeContext : InstPassBase if (isConcreteType(left) && isConcreteType(right)) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); bool result = left == right; inst->replaceUsesWith(builder.getBoolValue(result)); @@ -1123,6 +1179,8 @@ struct PeepholeContext : InstPassBase if (!SLANG_SUCCEEDED(res)) break; IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto stride = builder.getIntValue(inst->getDataType(), sizeAlignment.getStride()); @@ -1148,6 +1206,8 @@ struct PeepholeContext : InstPassBase if (isConcreteType(type)) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); bool result = false; switch (inst->getOp()) @@ -1186,6 +1246,8 @@ struct PeepholeContext : InstPassBase if (as<IRLoad>(inst)->getPtr()->getOp() == kIROp_undefined) { IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); auto undef = builder.emitUndefined(inst->getDataType()); inst->replaceUsesWith(undef); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index cdabb1ac2..f28f61ffc 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3995,7 +3995,7 @@ static TypeCastStyle _getTypeStyleId(IRType* type) } } -IRInst* IRBuilder::emitCast(IRType* type, IRInst* value) +IRInst* IRBuilder::emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinCast) { if (isTypeEqual(type, value->getDataType())) return value; @@ -4009,8 +4009,17 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value) SLANG_UNREACHABLE("cast from void type"); } - SLANG_RELEASE_ASSERT(toStyle != TypeCastStyle::Unknown); - SLANG_RELEASE_ASSERT(fromStyle != TypeCastStyle::Unknown); + if (toStyle == TypeCastStyle::Unknown || fromStyle == TypeCastStyle::Unknown) + { + if (fallbackToBuiltinCast) + { + return emitIntrinsicInst(type, kIROp_BuiltinCast, 1, &value); + } + else + { + return nullptr; + } + } struct OpSeq { @@ -4057,7 +4066,18 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value) auto t = type; if (op.op1 != kIROp_Nop) { - t = getUInt64Type(); + if (toStyle == TypeCastStyle::Bool) + t = getIntType(); + else + t = getUInt64Type(); + if (auto vecType = as<IRVectorType>(type)) + t = getVectorType(t, vecType->getElementCount()); + else if (auto matType = as<IRMatrixType>(type)) + t = getMatrixType( + t, + matType->getRowCount(), + matType->getColumnCount(), + matType->getLayout()); } auto result = emitIntrinsicInst(t, op.op0, 1, &value); if (op.op1 != kIROp_Nop) @@ -8293,6 +8313,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialWitnessTable: case kIROp_WrapExistential: + case kIROp_BuiltinCast: case kIROp_BitCast: case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ed8a52b9e..fbe6d8a84 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1737,6 +1737,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower context->irBuilder->getTypeEqualityWitness(witnessType, subType, supType)); } + LoweredValInfo visitTypeCoercionWitness(TypeCoercionWitness*) + { + // When we fully support type coercion constraints, we should lower the witness into a + // function that does the conversion. + return LoweredValInfo(); + } + LoweredValInfo visitTransitiveSubtypeWitness(TransitiveSubtypeWitness* val) { // The base (subToMid) will turn into a value with diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 82cb8caf3..aec3b4e90 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1721,6 +1721,20 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP constraint->sup = parser->ParseTypeExp(); AddMember(genericParent, constraint); } + else if (AdvanceIf(parser, TokenType::LParent)) + { + auto constraint = parser->astBuilder->create<TypeCoercionConstraintDecl>(); + constraint->whereTokenLoc = whereToken.loc; + parser->FillPosition(constraint); + constraint->toType = subType; + constraint->fromType = parser->ParseTypeExp(); + parser->ReadToken(TokenType::RParent); + if (AdvanceIf(parser, "implicit")) + { + addModifier(constraint, parser->astBuilder->create<ImplicitConversionModifier>()); + } + AddMember(genericParent, constraint); + } } } @@ -8910,8 +8924,19 @@ static NodeBase* parseImplicitConversionModifier(Parser* parser, void* /*userDat ConversionCost cost = kConversionCost_Default; if (AdvanceIf(parser, TokenType::LParent)) { - cost = - ConversionCost(stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); + if (AdvanceIf(parser, "constraint")) + { + cost = kConversionCost_TypeCoercionConstraint; + if (AdvanceIf(parser, TokenType::OpAdd)) + { + cost = kConversionCost_TypeCoercionConstraintPlusScalarToVector; + } + } + else + { + cost = ConversionCost( + stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); + } if (AdvanceIf(parser, TokenType::Comma)) { builtinKind = BuiltinConversionKind( |
