summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/core.meta.slang93
-rw-r--r--source/slang/slang-ast-builder.cpp8
-rw-r--r--source/slang/slang-ast-builder.h5
-rw-r--r--source/slang/slang-ast-decl.h9
-rw-r--r--source/slang/slang-ast-modifier.h4
-rw-r--r--source/slang/slang-ast-support-types.h5
-rw-r--r--source/slang/slang-ast-val.cpp52
-rw-r--r--source/slang/slang-ast-val.h14
-rw-r--r--source/slang/slang-check-constraint.cpp59
-rw-r--r--source/slang/slang-check-conversion.cpp29
-rw-r--r--source/slang/slang-check-decl.cpp12
-rw-r--r--source/slang/slang-check-impl.h5
-rw-r--r--source/slang/slang-check-overload.cpp9
-rw-r--r--source/slang/slang-ir-constexpr.cpp9
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-peephole.cpp62
-rw-r--r--source/slang/slang-ir.cpp29
-rw-r--r--source/slang/slang-lower-to-ir.cpp7
-rw-r--r--source/slang/slang-parser.cpp29
20 files changed, 349 insertions, 94 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index b7f50cd1b..267f7b2d4 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1,4 +1,4 @@
-//public module core;
+public module core;
// Slang `core` library
@@ -2367,49 +2367,6 @@ __generic<T> __extension vector<T, 4>
${{{{
-// The above extensions are generic in the *type* of the vector,
-// but explicit in the *size*. We will now declare an extension
-// for each builtin type that is generic in the size.
-//
-for (int tt = 0; tt < kBaseTypeCount; ++tt)
-{
- if(kBaseTypes[tt].tag == BaseType::Void) continue;
-
- sb << "__generic<let N : int> __extension vector<"
- << kBaseTypes[tt].name << ",N>\n{\n";
-
- for (int ff = 0; ff < kBaseTypeCount; ++ff)
- {
- if(kBaseTypes[ff].tag == BaseType::Void) continue;
-
-
- if( tt != ff )
- {
- auto cost = getBaseTypeConversionCost(
- kBaseTypes[tt],
- kBaseTypes[ff]);
- auto op = getBaseTypeConversionOp(
- kBaseTypes[tt],
- kBaseTypes[ff]);
-
- // Implicit conversion from a vector of the same
- // size, but different element type.
- sb << " __implicit_conversion(" << cost << ")\n";
- sb << " __intrinsic_op(" << int(op) << ")\n";
- sb << " __init(vector<" << kBaseTypes[ff].name << ",N> value);\n";
-
- // Constructor to make a vector from a scalar of another type.
- if (cost != kConversionCost_Impossible)
- {
- cost += kConversionCost_ScalarToVector;
- sb << " __implicit_conversion(" << cost << ")\n";
- sb << " [__unsafeForceInlineEarly]\n";
- sb << " __init(" << kBaseTypes[ff].name << " value) { this = vector<" << kBaseTypes[tt].name << ",N>( " << kBaseTypes[tt].name << "(value)); }\n";
- }
- }
- }
- sb << "}\n";
-}
for( int R = 1; R <= 4; ++R )
for( int C = 1; C <= 4; ++C )
@@ -2464,38 +2421,36 @@ for( int C = 1; C <= 4; ++C )
sb << "}\n";
}
-for (int tt = 0; tt < kBaseTypeCount; ++tt)
-{
- if(kBaseTypes[tt].tag == BaseType::Void) continue;
- auto toType = kBaseTypes[tt].name;
}}}}
-__generic<let R : int, let C : int, let L : int> extension matrix<$(toType),R,C,L>
+//@hidden:
+__intrinsic_op($(kIROp_BuiltinCast))
+internal T __builtin_cast<T, U>(U u);
+
+// If T is implicitly convertible to U, then vector<T,N> is implicitly convertible to vector<U,N>.
+__generic<ToType, let N : int> extension vector<ToType,N>
{
-${{{{
- for (int ff = 0; ff < kBaseTypeCount; ++ff)
- {
- if(kBaseTypes[ff].tag == BaseType::Void) continue;
- if( tt == ff ) continue;
+ __implicit_conversion(constraint)
+ __intrinsic_op(BuiltinCast)
+ __init<FromType>(vector<FromType,N> value) where ToType(FromType) implicit;
- auto cost = getBaseTypeConversionCost(
- kBaseTypes[tt],
- kBaseTypes[ff]);
- auto fromType = kBaseTypes[ff].name;
- auto op = getBaseTypeConversionOp(
- kBaseTypes[tt],
- kBaseTypes[ff]);
-}}}}
- __implicit_conversion($(cost))
- __intrinsic_op($(op))
- __init(matrix<$(fromType),R,C,L> value);
-${{{{
+ __implicit_conversion(constraint+)
+ [__unsafeForceInlineEarly]
+ [__readNone]
+ [TreatAsDifferentiable]
+ __init<FromType>(FromType value) where ToType(FromType) implicit
+ {
+ this = __builtin_cast<vector<ToType,N>>(vector<FromType,N>(value));
}
-}}}}
}
-${{{{
+
+// If T is implicitly convertible to U, then matrix<T,R,C,L> is implicitly convertible to matrix<U,R,C,L>.
+__generic<ToType, let R : int, let C : int, let L : int> extension matrix<ToType,R,C,L>
+{
+ __implicit_conversion(constraint)
+ __intrinsic_op(BuiltinCast)
+ __init<FromType>(matrix<FromType,R,C,L> value) where ToType(FromType) implicit;
}
-}}}}
//@ hidden:
__generic<T, U>
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 6ffaee7db..b3afa5310 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -948,6 +948,14 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
return witness;
}
+TypeCoercionWitness* ASTBuilder::getTypeCoercionWitness(
+ Type* subType,
+ Type* superType,
+ DeclRef<Decl> declRef)
+{
+ return getOrCreate<TypeCoercionWitness>(subType, superType, declRef.declRefBase);
+}
+
DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl)
{
return builder->getMemberDeclRef(parent, decl);
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index cae380e40..67dfaaf52 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -636,6 +636,11 @@ public:
SubtypeWitness* subIsLWitness,
SubtypeWitness* subIsRWitness);
+ TypeCoercionWitness* getTypeCoercionWitness(
+ Type* fromType,
+ Type* toType,
+ DeclRef<Decl> declRef);
+
/// Helpers to get type info from the SharedASTBuilder
const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice)
{
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index ff8e5684a..ff55340ac 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -612,6 +612,15 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl
const TypeExp& _getSupOverride() const { return sup; }
};
+class TypeCoercionConstraintDecl : public Decl
+{
+ SLANG_AST_CLASS(TypeCoercionConstraintDecl)
+
+ SourceLoc whereTokenLoc = SourceLoc();
+ TypeExp fromType;
+ TypeExp toType;
+};
+
class GenericValueParamDecl : public VarDeclBase
{
SLANG_AST_CLASS(GenericValueParamDecl)
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index cc4901236..e4d5ccd09 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1279,10 +1279,10 @@ class ImplicitConversionModifier : public Modifier
SLANG_AST_CLASS(ImplicitConversionModifier)
// The conversion cost, used to rank conversions
- ConversionCost cost;
+ ConversionCost cost = kConversionCost_None;
// A builtin identifier for identifying conversions that need special treatment.
- BuiltinConversionKind builtinConversionKind;
+ BuiltinConversionKind builtinConversionKind = kBuiltinConversion_Unknown;
};
class FormatAttribute : public Attribute
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index d24007721..b3baee98f 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -178,6 +178,11 @@ enum : ConversionCost
// Additional cost when casting an LValue.
kConversionCost_LValueCast = 800,
+ // The cost of this conversion is defined by the type coercion constraint.
+ kConversionCost_TypeCoercionConstraint = 1000,
+ kConversionCost_TypeCoercionConstraintPlusScalarToVector =
+ kConversionCost_TypeCoercionConstraint + kConversionCost_ScalarToVector,
+
// Conversion is impossible
kConversionCost_Impossible = 0xFFFFFFFF,
};
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 9bcfd21bc..7613dbe80 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -845,6 +845,58 @@ void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
out << ")";
}
+void TypeCoercionWitness::_toTextOverride(StringBuilder& out)
+{
+ out << "TypeCoercionWitness(";
+ if (getFromType())
+ out << getFromType();
+ if (getToType())
+ out << getToType();
+ out << ")";
+}
+
+Val* TypeCoercionWitness::_substituteImplOverride(
+ ASTBuilder* astBuilder,
+ SubstitutionSet subst,
+ int* ioDiff)
+{
+ int diff = 0;
+
+ auto substDeclRef = getDeclRef().substituteImpl(astBuilder, subst, &diff);
+ auto substFrom = as<Type>(getFromType()->substituteImpl(astBuilder, subst, &diff));
+ auto substTo = as<Type>(getToType()->substituteImpl(astBuilder, subst, &diff));
+
+ if (!diff)
+ return this;
+
+ (*ioDiff)++;
+
+ TypeCoercionWitness* substValue =
+ astBuilder->getTypeCoercionWitness(substFrom, substTo, substDeclRef);
+ return substValue;
+}
+
+Val* TypeCoercionWitness::_resolveImplOverride()
+{
+ Val* resolvedDeclRef = nullptr;
+ if (getDeclRef())
+ resolvedDeclRef = getDeclRef().declRefBase->resolve();
+ if (auto resolvedVal = as<Witness>(resolvedDeclRef))
+ return resolvedVal;
+
+ auto newFrom = as<Type>(getFromType()->resolve());
+ auto newTo = as<Type>(getToType()->resolve());
+
+ auto newDeclRef = as<DeclRefBase>(resolvedDeclRef);
+ if (!newDeclRef)
+ newDeclRef = getDeclRef().declRefBase;
+ if (newFrom != getFromType() || newTo != getToType() || newDeclRef != getDeclRef())
+ {
+ return getCurrentASTBuilder()->getTypeCoercionWitness(newFrom, newTo, newDeclRef);
+ }
+ return this;
+}
+
// UNormModifierVal
void UNormModifierVal::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 7b33a8111..3a14be17b 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -621,6 +621,20 @@ class TypeEqualityWitness : public SubtypeWitness
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
+class TypeCoercionWitness : public Witness
+{
+ SLANG_AST_CLASS(TypeCoercionWitness)
+
+ Type* getFromType() { return as<Type>(getOperand(0)); }
+ Type* getToType() { return as<Type>(getOperand(1)); }
+
+ DeclRef<Decl> getDeclRef() { return as<DeclRefBase>(getOperand(2)); }
+
+ void _toTextOverride(StringBuilder& out);
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
+};
+
// A witness that one type is a subtype of another
// because some in-scope declaration says so
class DeclaredSubtypeWitness : public SubtypeWitness
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 872d2616c..642a4bf6a 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -715,13 +715,6 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
// system as being solved now, as a result of the witness we found.
}
- // Add a flat cost to all unconstrained generic params.
- for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
- {
- if (!constrainedGenericParams.contains(typeParamDecl))
- outBaseCost += kConversionCost_UnconstraintGenericParam;
- }
-
// Make sure we haven't constructed any spurious constraints
// that we aren't able to satisfy:
for (auto c : system->constraints)
@@ -732,6 +725,58 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
}
}
+ // Verify that all type coercion constraints can be satisfied.
+ for (auto constraintDecl :
+ genericDeclRef.getDecl()->getMembersOfType<TypeCoercionConstraintDecl>())
+ {
+ DeclRef<TypeCoercionConstraintDecl> constraintDeclRef =
+ m_astBuilder
+ ->getGenericAppDeclRef(
+ genericDeclRef,
+ args.getArrayView().arrayView,
+ constraintDecl)
+ .as<TypeCoercionConstraintDecl>();
+ auto fromType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->fromType.Ptr());
+ auto toType = constraintDeclRef.substitute(m_astBuilder, constraintDecl->toType.Ptr());
+ auto conversionCost = getConversionCost(toType, fromType);
+ if (constraintDecl->findModifier<ImplicitConversionModifier>())
+ {
+ if (conversionCost > kConversionCost_GeneralConversion)
+ {
+ // The type arguments are not implicitly convertible, return failure.
+ return DeclRef<Decl>();
+ }
+ }
+ else
+ {
+ if (conversionCost == kConversionCost_Impossible)
+ {
+ // The type arguments are not convertible, return failure.
+ return DeclRef<Decl>();
+ }
+ }
+ if (auto fromDecl = isDeclRefTypeOf<Decl>(constraintDecl->fromType))
+ {
+ constrainedGenericParams.add(fromDecl.getDecl());
+ }
+ if (auto toDecl = isDeclRefTypeOf<Decl>(constraintDecl->toType))
+ {
+ constrainedGenericParams.add(toDecl.getDecl());
+ }
+ // If we are to expand the support of type coercion constraint beyond simple builtin core
+ // module functions, then the witness should be a reference to the conversion function. For
+ // now, this isn't required, and it is not easy to get it from the coercion logic, so we
+ // leave it empty.
+ args.add(m_astBuilder->getTypeCoercionWitness(fromType, toType, DeclRef<Decl>()));
+ }
+
+ // Add a flat cost to all unconstrained generic params.
+ for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
+ {
+ if (!constrainedGenericParams.contains(typeParamDecl))
+ outBaseCost += kConversionCost_UnconstraintGenericParam;
+ }
+
return m_astBuilder->getGenericAppDeclRef(genericDeclRef, args.getArrayView().arrayView);
}
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 6dda9c1ea..a9785a585 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1045,11 +1045,28 @@ int getTypeBitSize(Type* t)
}
ConversionCost SemanticsVisitor::getImplicitConversionCostWithKnownArg(
- Decl* decl,
+ DeclRef<Decl> decl,
Type* toType,
Expr* arg)
{
- ConversionCost candidateCost = getImplicitConversionCost(decl);
+ ConversionCost candidateCost = getImplicitConversionCost(decl.getDecl());
+
+ if (candidateCost == kConversionCost_TypeCoercionConstraint ||
+ candidateCost == kConversionCost_TypeCoercionConstraintPlusScalarToVector)
+ {
+ if (auto genApp = as<GenericAppDeclRef>(decl.declRefBase))
+ {
+ for (auto genArg : genApp->getArgs())
+ {
+ if (auto wit = as<TypeCoercionWitness>(genArg))
+ {
+ candidateCost -= kConversionCost_TypeCoercionConstraint;
+ candidateCost += getConversionCost(wit->getToType(), wit->getFromType());
+ break;
+ }
+ }
+ }
+ }
// Fix up the cost if the operand is a const lit.
if (isScalarIntegerType(toType))
@@ -1577,10 +1594,8 @@ bool SemanticsVisitor::_coerce(
ImplicitCastMethod method;
for (auto candidate : overloadContext.bestCandidates)
{
- ConversionCost candidateCost = getImplicitConversionCostWithKnownArg(
- candidate.item.declRef.getDecl(),
- toType,
- fromExpr);
+ ConversionCost candidateCost =
+ getImplicitConversionCostWithKnownArg(candidate.item.declRef, toType, fromExpr);
if (candidateCost < bestCost)
{
method.conversionFuncOverloadCandidate = candidate;
@@ -1632,7 +1647,7 @@ bool SemanticsVisitor::_coerce(
// cost associated with the initializer we are invoking.
//
ConversionCost cost = getImplicitConversionCostWithKnownArg(
- overloadContext.bestCandidate->item.declRef.getDecl(),
+ overloadContext.bestCandidate->item.declRef,
toType,
fromExpr);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1ef5b1cec..5b5e05b73 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -147,6 +147,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase,
void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl);
+ void visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl);
+
void validateGenericConstraintSubType(GenericTypeConstraintDecl* decl, TypeExp type);
void visitGenericDecl(GenericDecl* genericDecl);
@@ -2911,6 +2913,16 @@ void SemanticsDeclHeaderVisitor::validateGenericConstraintSubType(
}
}
+void SemanticsDeclHeaderVisitor::visitTypeCoercionConstraintDecl(TypeCoercionConstraintDecl* decl)
+{
+ CheckConstraintSubType(decl->toType);
+
+ if (!decl->fromType.type)
+ decl->fromType = TranslateTypeNodeForced(decl->fromType);
+ if (!decl->toType.type)
+ decl->toType = TranslateTypeNodeForced(decl->toType);
+}
+
void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
{
// TODO: are there any other validations we can do at this point?
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 59290f8ad..6438a91e3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1511,7 +1511,10 @@ public:
// perform implicit type conversion.
ConversionCost getImplicitConversionCost(Decl* decl);
- ConversionCost getImplicitConversionCostWithKnownArg(Decl* decl, Type* toType, Expr* arg);
+ ConversionCost getImplicitConversionCostWithKnownArg(
+ DeclRef<Decl> decl,
+ Type* toType,
+ Expr* arg);
BuiltinConversionKind getImplicitConversionBuiltinKind(Decl* decl);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index b944d2bf4..b75f95f9a 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1675,6 +1675,15 @@ int SemanticsVisitor::CompareOverloadCandidates(OverloadCandidate* left, Overloa
if (itemDiff)
return itemDiff;
+ // If one candidate is an implicit conversion, and other candidate is not,
+ // then we should prefer the implicit conversion.
+ int leftIsImplicitConversion =
+ left->item.declRef.getDecl()->findModifier<ImplicitConversionModifier>() ? 1 : 0;
+ int rightIsImplicitConversion =
+ right->item.declRef.getDecl()->findModifier<ImplicitConversionModifier>() ? 1 : 0;
+ if (leftIsImplicitConversion != rightIsImplicitConversion)
+ return rightIsImplicitConversion - leftIsImplicitConversion;
+
auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item);
if (specificityDiff)
return specificityDiff;
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index ff6d64319..620c65d4e 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -116,6 +116,7 @@ bool opCanBeConstExpr(IROp op)
case kIROp_PtrCast:
case kIROp_Reinterpret:
case kIROp_BitCast:
+ case kIROp_BuiltinCast:
case kIROp_MakeTuple:
case kIROp_MakeDifferentialPair:
case kIROp_MakeExistential:
@@ -178,7 +179,13 @@ bool opCanBeConstExprByBackwardPass(IRInst* value)
{
if (value->getOp() == kIROp_Param)
return isLoopPhi(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(value));
- return opCanBeConstExpr(value->getOp());
+ if (opCanBeConstExpr(value->getOp()))
+ return true;
+ if (auto callInst = as<IRCall>(value))
+ {
+ return !callInst->mightHaveSideEffects();
+ }
+ return false;
}
void markConstExpr(PropagateConstExprContext* context, IRInst* value)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 55880eab5..5a1966d00 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1202,6 +1202,7 @@ INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOIST
INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0)
INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
+INST(BuiltinCast, BuiltinCast, 1, 0)
INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
INST(Unmodified, unmodified, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index dbefa68c7..d64820aa6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4025,7 +4025,7 @@ public:
/// the inst.
IRInst* emitDefaultConstructRaw(IRType* type);
- IRInst* emitCast(IRType* type, IRInst* value);
+ IRInst* emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinCast = true);
IRInst* emitVectorReshape(IRType* type, IRInst* value);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index fc399954b..e29fdf975 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -98,6 +98,7 @@ struct PeepholeContext : InstPassBase
else if (remainingKeys.getCount() > 0)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
auto newValue =
builder.emitElementExtract(updateInst->getElementValue(), remainingKeys);
@@ -112,6 +113,7 @@ struct PeepholeContext : InstPassBase
// accessChain!=accessChain2, then we can replace the inst with extract(x,
// accessChain2).
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
auto newInst =
builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView());
@@ -140,6 +142,8 @@ struct PeepholeContext : InstPassBase
if (vectorType->getElementType() != replacement->getFullType())
return false;
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
replacement =
builder.emitMakeVectorFromScalar(inst->getFullType(), replacement);
@@ -175,6 +179,7 @@ struct PeepholeContext : InstPassBase
else if (inst->getOperand(0) == inst->getOperand(1))
{
IRBuilder builder(inst);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
return tryReplace(builder.emitDefaultConstruct(inst->getDataType()));
}
@@ -280,6 +285,8 @@ struct PeepholeContext : InstPassBase
break;
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
IRInst* resultVal = nullptr;
if (inst->getOp() == kIROp_AlignOf)
@@ -319,6 +326,8 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_MakeResultError)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
inst->replaceUsesWith(builder.getBoolValue(true));
maybeRemoveOldInst(inst);
changed = true;
@@ -326,6 +335,8 @@ struct PeepholeContext : InstPassBase
else if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
inst->replaceUsesWith(builder.getBoolValue(false));
maybeRemoveOldInst(inst);
changed = true;
@@ -359,6 +370,8 @@ struct PeepholeContext : InstPassBase
if (const auto packType = as<IRTypePack>(pack->getDataType()))
{
IRBuilder builder(inst);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
List<IRInst*> args;
for (UInt j = 0; j < packType->getOperandCount(); ++j)
@@ -443,6 +456,8 @@ struct PeepholeContext : InstPassBase
index->getValue() < startIndex + vecSize->getValue())
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto newElement = builder.emitElementExtract(
element,
@@ -517,6 +532,8 @@ struct PeepholeContext : InstPassBase
if (args.getCount() == arraySize->getValue())
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto makeArray = builder.emitMakeArray(
arrayType,
@@ -573,6 +590,8 @@ struct PeepholeContext : InstPassBase
if (isComplete)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto makeArray = builder.emitMakeArray(
arrayType,
@@ -618,6 +637,8 @@ struct PeepholeContext : InstPassBase
if (isValid)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto makeStruct = builder.emitMakeStruct(
structType,
@@ -678,6 +699,8 @@ struct PeepholeContext : InstPassBase
// Create a makeStruct inst using args.
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto makeStruct = builder.emitMakeStruct(
structType,
@@ -694,6 +717,8 @@ struct PeepholeContext : InstPassBase
{
auto ptr = inst->getOperand(0);
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto neq = builder.emitNeq(ptr, builder.getNullPtrValue(ptr->getDataType()));
inst->replaceUsesWith(neq);
@@ -708,6 +733,8 @@ struct PeepholeContext : InstPassBase
if (isTypeEqual(actualType, (IRType*)isTypeInst->getTypeOperand()))
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto trueVal = builder.getBoolValue(true);
inst->replaceUsesWith(trueVal);
@@ -770,6 +797,7 @@ struct PeepholeContext : InstPassBase
if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
auto trueVal = builder.getBoolValue(true);
inst->replaceUsesWith(trueVal);
@@ -779,6 +807,8 @@ struct PeepholeContext : InstPassBase
else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto falseVal = builder.getBoolValue(false);
inst->replaceUsesWith(falseVal);
@@ -841,6 +871,7 @@ struct PeepholeContext : InstPassBase
case kIROp_DefaultConstruct:
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
// See if we can replace the default construct inst with concrete values.
if (auto newCtor = builder.emitDefaultConstruct(inst->getFullType(), false))
@@ -851,6 +882,21 @@ struct PeepholeContext : InstPassBase
}
}
break;
+ case kIROp_BuiltinCast:
+ {
+ IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+ builder.setInsertBefore(inst);
+ // See if we can replace the default construct inst with concrete values.
+ if (auto newCast =
+ builder.emitCast(inst->getFullType(), inst->getOperand(0), false))
+ {
+ inst->replaceUsesWith(newCast);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ }
+ break;
case kIROp_VectorReshape:
{
auto fromType = as<IRVectorType>(inst->getOperand(0)->getDataType());
@@ -867,6 +913,7 @@ struct PeepholeContext : InstPassBase
break;
}
IRBuilder builder(inst);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
UInt index = 0;
auto newInst = builder.emitSwizzle(resultType, inst->getOperand(0), 1, &index);
@@ -882,6 +929,8 @@ struct PeepholeContext : InstPassBase
if (!toCount)
break;
IRBuilder builder(inst);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto newInst = builder.emitVectorReshape(resultType, inst->getOperand(0));
if (newInst != inst)
@@ -911,6 +960,7 @@ struct PeepholeContext : InstPassBase
break;
List<IRInst*> rows;
IRBuilder builder(inst);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
builder.setInsertBefore(inst);
auto toRowType = builder.getVectorType(
resultType->getElementType(),
@@ -1035,6 +1085,8 @@ struct PeepholeContext : InstPassBase
break;
}
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto newInst =
builder.emitMakeVectorFromScalar(vectorType, inst->getOperand(0));
@@ -1075,6 +1127,8 @@ struct PeepholeContext : InstPassBase
else
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto newMakeVector = builder.emitMakeVector(
swizzle->getDataType(),
@@ -1100,6 +1154,8 @@ struct PeepholeContext : InstPassBase
if (isConcreteType(left) && isConcreteType(right))
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
bool result = left == right;
inst->replaceUsesWith(builder.getBoolValue(result));
@@ -1123,6 +1179,8 @@ struct PeepholeContext : InstPassBase
if (!SLANG_SUCCEEDED(res))
break;
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto stride =
builder.getIntValue(inst->getDataType(), sizeAlignment.getStride());
@@ -1148,6 +1206,8 @@ struct PeepholeContext : InstPassBase
if (isConcreteType(type))
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
bool result = false;
switch (inst->getOp())
@@ -1186,6 +1246,8 @@ struct PeepholeContext : InstPassBase
if (as<IRLoad>(inst)->getPtr()->getOp() == kIROp_undefined)
{
IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+
builder.setInsertBefore(inst);
auto undef = builder.emitUndefined(inst->getDataType());
inst->replaceUsesWith(undef);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index cdabb1ac2..f28f61ffc 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3995,7 +3995,7 @@ static TypeCastStyle _getTypeStyleId(IRType* type)
}
}
-IRInst* IRBuilder::emitCast(IRType* type, IRInst* value)
+IRInst* IRBuilder::emitCast(IRType* type, IRInst* value, bool fallbackToBuiltinCast)
{
if (isTypeEqual(type, value->getDataType()))
return value;
@@ -4009,8 +4009,17 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value)
SLANG_UNREACHABLE("cast from void type");
}
- SLANG_RELEASE_ASSERT(toStyle != TypeCastStyle::Unknown);
- SLANG_RELEASE_ASSERT(fromStyle != TypeCastStyle::Unknown);
+ if (toStyle == TypeCastStyle::Unknown || fromStyle == TypeCastStyle::Unknown)
+ {
+ if (fallbackToBuiltinCast)
+ {
+ return emitIntrinsicInst(type, kIROp_BuiltinCast, 1, &value);
+ }
+ else
+ {
+ return nullptr;
+ }
+ }
struct OpSeq
{
@@ -4057,7 +4066,18 @@ IRInst* IRBuilder::emitCast(IRType* type, IRInst* value)
auto t = type;
if (op.op1 != kIROp_Nop)
{
- t = getUInt64Type();
+ if (toStyle == TypeCastStyle::Bool)
+ t = getIntType();
+ else
+ t = getUInt64Type();
+ if (auto vecType = as<IRVectorType>(type))
+ t = getVectorType(t, vecType->getElementCount());
+ else if (auto matType = as<IRMatrixType>(type))
+ t = getMatrixType(
+ t,
+ matType->getRowCount(),
+ matType->getColumnCount(),
+ matType->getLayout());
}
auto result = emitIntrinsicInst(t, op.op0, 1, &value);
if (op.op1 != kIROp_Nop)
@@ -8293,6 +8313,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialWitnessTable:
case kIROp_WrapExistential:
+ case kIROp_BuiltinCast:
case kIROp_BitCast:
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ed8a52b9e..fbe6d8a84 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1737,6 +1737,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
context->irBuilder->getTypeEqualityWitness(witnessType, subType, supType));
}
+ LoweredValInfo visitTypeCoercionWitness(TypeCoercionWitness*)
+ {
+ // When we fully support type coercion constraints, we should lower the witness into a
+ // function that does the conversion.
+ return LoweredValInfo();
+ }
+
LoweredValInfo visitTransitiveSubtypeWitness(TransitiveSubtypeWitness* val)
{
// The base (subToMid) will turn into a value with
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 82cb8caf3..aec3b4e90 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1721,6 +1721,20 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP
constraint->sup = parser->ParseTypeExp();
AddMember(genericParent, constraint);
}
+ else if (AdvanceIf(parser, TokenType::LParent))
+ {
+ auto constraint = parser->astBuilder->create<TypeCoercionConstraintDecl>();
+ constraint->whereTokenLoc = whereToken.loc;
+ parser->FillPosition(constraint);
+ constraint->toType = subType;
+ constraint->fromType = parser->ParseTypeExp();
+ parser->ReadToken(TokenType::RParent);
+ if (AdvanceIf(parser, "implicit"))
+ {
+ addModifier(constraint, parser->astBuilder->create<ImplicitConversionModifier>());
+ }
+ AddMember(genericParent, constraint);
+ }
}
}
@@ -8910,8 +8924,19 @@ static NodeBase* parseImplicitConversionModifier(Parser* parser, void* /*userDat
ConversionCost cost = kConversionCost_Default;
if (AdvanceIf(parser, TokenType::LParent))
{
- cost =
- ConversionCost(stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent()));
+ if (AdvanceIf(parser, "constraint"))
+ {
+ cost = kConversionCost_TypeCoercionConstraint;
+ if (AdvanceIf(parser, TokenType::OpAdd))
+ {
+ cost = kConversionCost_TypeCoercionConstraintPlusScalarToVector;
+ }
+ }
+ else
+ {
+ cost = ConversionCost(
+ stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent()));
+ }
if (AdvanceIf(parser, TokenType::Comma))
{
builtinKind = BuiltinConversionKind(