summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-10-20 09:28:13 -0700
committerGitHub <noreply@github.com>2024-10-20 09:28:13 -0700
commit307315a7305e76529837fd1cdb677f534d5f539b (patch)
treeba39e96ba2e9b3d62d1213aab2f1cc54febe451a /source
parent9936178dd3efb026bfa142512a2bf061d7a75ab5 (diff)
Properly check switch case. (#5341)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h8
-rw-r--r--source/slang/slang-ast-iterator.h4
-rw-r--r--source/slang/slang-check-conversion.cpp18
-rw-r--r--source/slang/slang-check-expr.cpp20
-rw-r--r--source/slang/slang-check-impl.h3
-rw-r--r--source/slang/slang-check-overload.cpp33
-rw-r--r--source/slang/slang-check-stmt.cpp22
-rw-r--r--source/slang/slang-emit-c-like.cpp7
-rw-r--r--source/slang/slang-emit-c-like.h4
-rw-r--r--source/slang/slang-emit-wgsl.cpp35
-rw-r--r--source/slang/slang-emit-wgsl.h5
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp7
-rw-r--r--source/slang/slang-syntax.h5
14 files changed, 107 insertions, 69 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index f9b8831f1..8b779e8db 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -329,6 +329,14 @@ class ImplicitCastExpr : public TypeCastExpr
SLANG_AST_CLASS(ImplicitCastExpr)
};
+// A builtin cast expr generated during semantic checking, where there is
+// no associated conversion function decl.
+class BuiltinCastExpr : public Expr
+{
+ SLANG_AST_CLASS(BuiltinCastExpr);
+ Expr* base = nullptr;
+};
+
class LValueImplicitCastExpr : public TypeCastExpr
{
SLANG_AST_CLASS(LValueImplicitCastExpr)
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index 24f98391c..c6f74fdf8 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -73,6 +73,10 @@ struct ASTIterator
dispatchIfNotNull(arg);
}
+ void visitBuiltinCastExpr(BuiltinCastExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
void visitParenExpr(ParenExpr* expr)
{
iterator->maybeDispatchCallback(expr);
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index c0d7feaff..586f44887 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -960,7 +960,11 @@ namespace Slang
}
if (outToExpr)
{
- *outToExpr = fromExpr;
+ auto rsExpr = getASTBuilder()->create<BuiltinCastExpr>();
+ rsExpr->type = toType;
+ rsExpr->loc = fromExpr->loc;
+ rsExpr->base = fromExpr;
+ *outToExpr = rsExpr;
}
return true;
}
@@ -1150,7 +1154,7 @@ namespace Slang
// call to one of the initializers in the target type.
OverloadResolveContext overloadContext;
- overloadContext.disallowNestedConversions = true;
+ overloadContext.disallowNestedConversions = (site != CoercionSite::ExplicitCoercion);
overloadContext.argCount = 1;
List<Expr*> args;
args.add(fromExpr);
@@ -1295,7 +1299,7 @@ namespace Slang
// but then emit a diagnostic when actually reifying
// the result expression.
//
- if (outToExpr)
+ if (outToExpr && site != CoercionSite::ExplicitCoercion)
{
if (cost >= kConversionCost_Explicit)
{
@@ -1362,7 +1366,9 @@ namespace Slang
// base expression (the callee), since that will come
// from the selected overload candidate.
//
- auto castExpr = createImplicitCastExpr();
+ InvokeExpr* castExpr = (site == CoercionSite::ExplicitCoercion)
+ ? m_astBuilder->create<ExplicitCastExpr>()
+ : createImplicitCastExpr();
castExpr->loc = fromExpr->loc;
castExpr->arguments.add(fromExpr);
//
@@ -1379,7 +1385,7 @@ namespace Slang
// "argument list" was just a pointer to `fromExpr`.
//
// That means we need to clear the argument list and
- // reload it from `fromExpr` to make sure that we
+ // reload it from `args[0]` to make sure that we
// got the arguments *after* any transformations
// were applied.
// For right now this probably doesn't matter,
@@ -1387,7 +1393,7 @@ namespace Slang
// but I'd rather play it safe.
//
castExpr->arguments.clear();
- castExpr->arguments.add(fromExpr);
+ castExpr->arguments.add(args[0]);
}
if (!cachedMethod)
getShared()->cacheImplicitCastMethod(implicitCastKey, ImplicitCastMethod{ *overloadContext.bestCandidate, cost });
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index ec4182059..41cbf689b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1979,17 +1979,23 @@ namespace Slang
}
}
- if(auto castExpr = expr.as<TypeCastExpr>())
+ SubstExpr<Expr> typeCastOperand;
+ if (auto typeCastExpr = expr.as<TypeCastExpr>())
+ typeCastOperand = getArg(typeCastExpr, 0);
+ else if (auto builtinCastExpr = expr.as<BuiltinCastExpr>())
+ typeCastOperand = getBaseExpr(builtinCastExpr);
+
+ if (typeCastOperand)
{
auto substType = getType(m_astBuilder, expr);
if (!substType)
return nullptr;
if (!isValidCompileTimeConstantType(substType))
return nullptr;
- auto val = tryConstantFoldExpr(getArg(castExpr, 0), kind, circularityInfo);
+ auto val = tryConstantFoldExpr(typeCastOperand, kind, circularityInfo);
if (val)
{
- if (!castExpr.getExpr()->type)
+ if (!expr.getExpr()->type)
return nullptr;
auto foldVal = as<IntVal>(
TypeCastIntVal::tryFoldImpl(m_astBuilder, substType, val, getSink()));
@@ -2105,6 +2111,8 @@ namespace Slang
case IntegerConstantExpressionCoercionType::AnyInteger:
if (isScalarIntegerType(inExpr->type))
expr = inExpr;
+ else if (isEnumType(inExpr->type))
+ expr = inExpr;
else
expr = coerce(CoercionSite::General, m_astBuilder->getIntType(), inExpr);
break;
@@ -3491,6 +3499,12 @@ namespace Slang
return sizeOfLikeExpr;
}
+ Expr* SemanticsExprVisitor::visitBuiltinCastExpr(BuiltinCastExpr* expr)
+ {
+ // All builtin cast exprs should already be checked.
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr)
{
if (expr->type)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 5d70c36d5..cecae31fa 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -292,6 +292,7 @@ namespace Slang
Argument,
Return,
Initializer,
+ ExplicitCoercion
};
struct FacetImpl;
@@ -2818,6 +2819,8 @@ namespace Slang
Expr* visitTypeCastExpr(TypeCastExpr * expr);
+ Expr* visitBuiltinCastExpr(BuiltinCastExpr* expr);
+
Expr* visitTryExpr(TryExpr* expr);
Expr* visitIsTypeExpr(IsTypeExpr* expr);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index d0441397d..41bac4bb0 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -2097,7 +2097,7 @@ namespace Slang
// by doing so we could weed out cases where a type is "constructed"
// from a value of the same type. There is no need in Slang for
// "copy constructors" but the stdlib currently has to define
- // some just to make code that does, e.g., `float(1.0f)` work.
+ // some just to make code that does, e.g., `float(1.0f)` work.)
LookupResult initializers = lookUpMember(
m_astBuilder,
@@ -2404,7 +2404,7 @@ namespace Slang
return argsListBuilder.produceString();
}
- Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr * expr)
+ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
{
OverloadResolveContext context;
// check if this is a stdlib operator call, if so we want to use cached results
@@ -2470,7 +2470,7 @@ namespace Slang
context.sourceScope = m_outerScope;
context.baseExpr = GetBaseExpr(funcExpr);
- // TODO: We should have a special case here where an `InvokeExpr`
+ // We run a special case here where an `InvokeExpr`
// with a single argument where the base/func expression names
// a type should always be treated as an explicit type coercion
// (and hence bottleneck through `coerce()`) instead of just
@@ -2484,8 +2484,33 @@ namespace Slang
// that `(T) expr` and `T(expr)` continue to be semantically
// `visitTypeCastExpr`) would allow us to continue to ensure
// equivalent in (almost) all cases.
+ // If callee is a type, and we are calling with one argument, then treat it as a
+ // type coercion.
+ bool typeOverloadChecked = false;
- if (!context.bestCandidate)
+ if (expr->arguments.getCount() == 1)
+ {
+ if (const auto typeType = as<TypeType>(funcExpr->type))
+ {
+ if (isDeclRefTypeOf<AggTypeDeclBase>(typeType->getType()))
+ {
+ Expr* resultExpr = nullptr;
+ DiagnosticSink tempSink(getSourceManager(), nullptr);
+ ConversionCost conversionCost = kConversionCost_None;
+ auto coerceResult = SemanticsVisitor(withSink(&tempSink))._coerce(
+ CoercionSite::ExplicitCoercion,
+ typeType->getType(),
+ &resultExpr,
+ expr->arguments[0]->type,
+ expr->arguments[0],
+ &conversionCost);
+ if (coerceResult)
+ return resultExpr;
+ typeOverloadChecked = true;
+ }
+ }
+ }
+ if (!context.bestCandidate && !typeOverloadChecked)
{
AddOverloadCandidates(funcExpr, context);
}
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index ae817f867..8b0e0b284 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -294,23 +294,21 @@ namespace Slang
void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
{
- auto expr = CheckExpr(stmt->expr);
-
- // coerce to type being switch on, and ensure that value is a compile-time constant
- // The Vals in the AST are pointer-unique, making them easy to check for duplicates
- // by addeing them to a HashSet.
- auto exprVal = tryConstantFoldExpr(expr, ConstantFoldingKind::CompileTime, nullptr);
auto switchStmt = FindOuterStmt<SwitchStmt>();
-
if (!switchStmt)
{
getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
+ return;
}
- else
- {
- // TODO: need to do some basic matching to ensure the type
- // for the `case` is consistent with the type for the `switch`...
- }
+
+ // Check that the type for the `case` is consistent with the type for the `switch`.
+ auto expr = CheckExpr(stmt->expr);
+ expr = coerce(CoercionSite::Argument, switchStmt->condition->type, expr);
+
+ // coerce to type being switch on, and ensure that value is a compile-time constant
+ // The Vals in the AST are pointer-unique, making them easy to check for duplicates
+ // by addeing them to a HashSet.
+ auto exprVal = checkConstantIntVal(expr);
stmt->expr = expr;
stmt->exprVal = exprVal;
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index dd5cb88d3..a2795675d 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -3232,7 +3232,7 @@ void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSe
emitLayoutSemanticsImpl(inst, uniformSemanticSpelling, EmitLayoutSemanticOption::kPostType);
}
-void CLikeSourceEmitter::emitSwitchCaseSelectorsImpl(IRBasicType *const /* switchCondition */, const SwitchRegion::Case *const currentCase, const bool isDefault)
+void CLikeSourceEmitter::emitSwitchCaseSelectorsImpl(const SwitchRegion::Case* currentCase, bool isDefault)
{
for(auto caseVal : currentCase->values)
{
@@ -3401,9 +3401,8 @@ void CLikeSourceEmitter::emitRegion(Region* inRegion)
auto defaultCase = switchRegion->defaultCase;
for(auto currentCase : switchRegion->cases)
{
- const bool isDefault {currentCase.Ptr() == defaultCase};
- IRBasicType *const switchConditionType {as<IRBasicType>(switchRegion->getCondition()->getDataType())};
- emitSwitchCaseSelectors(switchConditionType, currentCase.Ptr(), isDefault);
+ bool isDefault = (currentCase.Ptr() == defaultCase);
+ emitSwitchCaseSelectors(currentCase.Ptr(), isDefault);
m_writer->indent();
m_writer->emit("{\n");
m_writer->indent();
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index f0d703b40..41fb21fe8 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -371,7 +371,7 @@ public:
void emitFuncHeader(IRFunc* func) { emitFuncHeaderImpl(func); }
void emitSimpleFunc(IRFunc* func) { emitSimpleFuncImpl(func); }
- void emitSwitchCaseSelectors(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault) {emitSwitchCaseSelectorsImpl(switchConditionType, currentCase, isDefault);}
+ void emitSwitchCaseSelectors(const SwitchRegion::Case* currentCase, bool isDefault) {emitSwitchCaseSelectorsImpl(currentCase, isDefault);}
void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); }
@@ -524,7 +524,7 @@ public:
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { SLANG_UNUSED(decl); }
virtual void emitIfDecorationsImpl(IRIfElse* ifInst) { SLANG_UNUSED(ifInst); }
virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) { SLANG_UNUSED(switchInst); }
- virtual void emitSwitchCaseSelectorsImpl(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault);
+ virtual void emitSwitchCaseSelectorsImpl(const SwitchRegion::Case* currentCase, bool isDefault);
virtual void emitFuncDecorationImpl(IRDecoration* decoration) { SLANG_UNUSED(decoration); }
virtual void emitLivenessImpl(IRInst* inst);
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index b1a723dc5..105f5c3cf 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -27,7 +27,6 @@ namespace Slang
{
void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
- IRBasicType *const switchConditionType,
const SwitchRegion::Case *const currentCase,
const bool isDefault)
{
@@ -38,39 +37,7 @@ void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
m_writer->emit("case ");
for (auto caseVal : currentCase->values)
{
- // TODO: Fix this in the front-end [1], remove the if-path and just do the else-path.
- // We can't do that at the moment because it would break Falcor [2].
- // [1] https://github.com/shader-slang/slang/pull/5025/commits/a32156ef52f43b8503b2c77f2f1d51220ab9bdea
- // [2] https://github.com/shader-slang/slang/pull/5025#issuecomment-2334495120
- if (caseVal->getOp() == kIROp_IntLit)
- {
- auto caseLitInst = static_cast<IRConstant*>(caseVal);
- IRBasicType *const caseInstType = as<IRBasicType>(caseLitInst->getDataType());
- // WGSL doesn't allow switch condition and case type mismatches, see [1].
- // Thus we need to insert explicit conversions.
- // Doing a wrapping cast will match Slang's de facto semantics, according to
- // [2].
- // (This is just a bitcast, assuming a two's complement representation.)
- // [1] https://www.w3.org/TR/WGSL/#switch-statement
- // [2] https://github.com/shader-slang/slang/issues/4921
- const bool needBitcast =
- caseInstType->getBaseType() != switchConditionType->getBaseType();
- if (needBitcast)
- {
- m_writer->emit("bitcast<");
- emitType(switchConditionType);
- m_writer->emit(">(");
- }
- emitOperand(caseVal, getInfo(EmitOp::General));
- if (needBitcast)
- {
- m_writer->emit(")");
- }
- }
- else
- {
- emitOperand(caseVal, getInfo(EmitOp::General));
- }
+ emitOperand(caseVal, getInfo(EmitOp::General));
m_writer->emit(", ");
}
if (isDefault)
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index 0b4b04b12..703310d1e 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -21,10 +21,7 @@ public:
virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE;
virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE;
- virtual void emitSwitchCaseSelectorsImpl(
- IRBasicType *const switchCondition,
- const SwitchRegion::Case *const currentCase,
- const bool isDefault) SLANG_OVERRIDE;
+ virtual void emitSwitchCaseSelectorsImpl(const SwitchRegion::Case* currentCase, bool isDefault) SLANG_OVERRIDE;
virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;
virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE;
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 06b7c937f..537218769 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -144,6 +144,11 @@ public:
return dispatchIfNotNull(expr->base);
}
+ bool visitBuiltinCastExpr(BuiltinCastExpr* expr)
+ {
+ return dispatchIfNotNull(expr->base);
+ }
+
bool visitAssignExpr(AssignExpr* expr)
{
if (dispatchIfNotNull(expr->left))
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b37df4b41..f5ec90d03 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4936,6 +4936,13 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return sharedLoweringContext.visitInvokeExprImpl(expr, LoweredValInfo(), TryClauseEnvironment());
}
+ LoweredValInfo visitBuiltinCastExpr(BuiltinCastExpr* expr)
+ {
+ auto irType = lowerType(context, expr->type);
+ auto irVal = getSimpleVal(context, lowerRValueExpr(context, expr->base));
+ return LoweredValInfo::simple(context->irBuilder->emitCast(irType, irVal));
+ }
+
/// Emit code for a `try` invoke.
LoweredValInfo visitTryExpr(TryExpr* expr)
{
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index ba7909aae..207a8744d 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -247,6 +247,11 @@ namespace Slang
return substituteExpr(expr.getSubsts(), expr.getExpr()->base);
}
+ inline SubstExpr<Expr> getBaseExpr(SubstExpr<BuiltinCastExpr> expr)
+ {
+ return substituteExpr(expr.getSubsts(), expr.getExpr()->base);
+ }
+
inline SubstExpr<Expr> getBaseExpr(SubstExpr<InvokeExpr> expr)
{
return substituteExpr(expr.getSubsts(), expr.getExpr()->functionExpr);