summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/user-guide/06-interfaces-generics.md17
-rw-r--r--source/slang/slang-check-expr.cpp42
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-peephole.cpp10
-rw-r--r--source/slang/slang-ir.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp52
-rw-r--r--tests/language-feature/types/as-is-generic.slang33
8 files changed, 121 insertions, 41 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md
index 59a0352fa..2436dab07 100644
--- a/docs/user-guide/06-interfaces-generics.md
+++ b/docs/user-guide/06-interfaces-generics.md
@@ -671,6 +671,23 @@ void main()
// "success"
```
+In addition to casting from an interface type to a concrete type, `as` and `is` operator can be used on generic types as well to cast a generic type into a concrete type. For example:
+```csharp
+T compute<T>(T a1, T a2)
+{
+ if (a1 is float)
+ {
+ return reinterpret<T>((a1 as float).value + (a2 as float).value);
+ }
+ else if (T is int)
+ {
+ return reinterpret<T>((a1 as int).value - (a2 as int).value);
+ }
+ return T();
+}
+// compute(1.0f, 2.0f) == 3.0f
+// compute(3, 1) == 2
+```
Extensions to Interfaces
-----------------------------
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 89a7373ee..33a1fa680 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -3183,8 +3183,12 @@ namespace Slang
expr->type = m_astBuilder->getBoolType();
expr->value = originalVal;
+ auto valueType = expr->value->type.type;
+ if (auto typeType = as<TypeType>(valueType))
+ valueType = typeType->getType();
+
// If value is a subtype of `type`, then this expr is always true.
- if(isSubtype(expr->value->type.type, expr->typeExpr.type))
+ if(isSubtype(valueType, expr->typeExpr.type))
{
// Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario,
// so that the language server can still see the original syntax tree.
@@ -3195,10 +3199,11 @@ namespace Slang
return expr;
}
- // Otherwise, we need to ensure the target type is a subtype of value->type.
+ // Otherwise, if the target type is a subtype of value->type, we need to grab the
+ // subtype witness for runtime checks.
expr->value = maybeOpenExistential(originalVal);
- expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, originalVal->type.type);
+ expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, valueType);
if (expr->witnessArg)
{
// For now we can only support the scenario where `expr->value` is an interface type.
@@ -3208,15 +3213,6 @@ namespace Slang
}
return expr;
}
-
- if (!as<ErrorType>(expr->typeExpr.type) && !as<ErrorType>(expr->value->type.type))
- {
- // The type is not in the same hierarchy, so we evaluate to false.
- expr->constantVal = m_astBuilder->create<BoolLiteralExpr>();
- expr->constantVal->type = m_astBuilder->getBoolType();
- expr->constantVal->value = false;
- expr->constantVal->loc = expr->loc;
- }
return expr;
}
@@ -3241,27 +3237,21 @@ namespace Slang
return makeOptional;
}
- // For now we can only support the scenario where `expr->value` is an interface type.
- if (!isInterfaceType(expr->value->type))
- {
- getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType);
- }
-
- expr->typeExpr = typeExpr.exp;
+ // If target type is an interface type, we will obtain the witness here for
+ // runtime casting.
expr->witnessArg = tryGetSubtypeWitness(typeExpr.type, expr->value->type.type);
if (expr->witnessArg)
{
+ // For now we can only support the scenario where `expr->value` is an interface type.
+ if (!isInterfaceType(expr->value->type.type))
+ {
+ getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType);
+ }
expr->value = maybeOpenExistential(expr->value);
return expr;
}
- if (!as<ErrorType>(typeExpr.type) && !as<ErrorType>(expr->value->type.type))
- {
- getSink()->diagnose(expr, Diagnostics::typeNotInTheSameHierarchy, expr->value->type.type, typeExpr.type);
- }
-
- expr->type = m_astBuilder->getErrorType();
-
+ expr->typeExpr = typeExpr.exp;
return expr;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 037e441d4..2150be5ad 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -282,7 +282,6 @@ DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload foun
DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'")
DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
-DIAGNOSTIC(30018, Error, typeNotInTheSameHierarchy, "invalid use of 'as' operator: expression evaluates to '$0', which is not in the same type hierarchy as target type '$1'.")
DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'")
DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)")
DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 763d96840..2217bc143 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -478,6 +478,8 @@ Result linkAndOptimizeIR(
}
lowerReinterpret(targetProgram, irModule, sink);
+ if (sink->getErrorCount() != 0)
+ return SLANG_FAIL;
validateIRModuleIfEnabled(codeGenContext, irModule);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index d7c3c6bda..d4369da7a 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -958,8 +958,14 @@ struct PeepholeContext : InstPassBase
}
case kIROp_TypeEquals:
{
- auto left = inst->getOperand(0)->getDataType();
- auto right = inst->getOperand(1)->getDataType();
+ auto getTypeFromOperand = [](IRInst* operand) -> IRType*
+ {
+ if (as<IRTypeType>(operand->getFullType()) || !operand->getFullType())
+ return (IRType*)operand;
+ return operand->getFullType();
+ };
+ auto left = getTypeFromOperand(inst->getOperand(0));
+ auto right = getTypeFromOperand(inst->getOperand(1));
if (isConcreteType(left) && isConcreteType(right))
{
IRBuilder builder(module);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ff9c6f6f8..fd5ea0fc7 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -31,6 +31,11 @@ namespace Slang
{
if (!irObject)
return;
+ if (as<IRType>(irObject))
+ {
+ getTypeNameHint(sb, irObject);
+ return;
+ }
if (auto nameHint = irObject->findDecoration<IRNameHintDecoration>())
{
sb << nameHint->getName();
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 409cd65ee..638a9c577 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4314,9 +4314,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
getBuilder()->emitMakeStruct(irType, args.getCount(), args.getBuffer()));
}
}
-
- SLANG_UNEXPECTED("unexpected type when creating default value");
- UNREACHABLE_RETURN(LoweredValInfo());
+ return LoweredValInfo::simple(getBuilder()->emitDefaultConstruct(irType));
}
LoweredValInfo getDefaultVal(DeclRef<VarDeclBase> decl)
@@ -4718,20 +4716,32 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitAsTypeExpr(AsTypeExpr* expr)
{
auto value = lowerLValueExpr(context, expr->value);
- auto existentialInfo = value.getExtractedExistentialValInfo();
+ ExtractedExistentialValInfo* existentialInfo = nullptr;
auto optType = lowerType(context, expr->type);
SLANG_RELEASE_ASSERT(optType->getOp() == kIROp_OptionalType);
auto targetType = optType->getOperand(0);
- auto witness = lowerSimpleVal(context, expr->witnessArg);
auto builder = getBuilder();
auto var = builder->emitVar(optType);
- auto isType = builder->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, targetType, witness);
+ IRInst* isType = nullptr;
+ if (expr->witnessArg)
+ {
+ auto witness = lowerSimpleVal(context, expr->witnessArg);
+ existentialInfo = value.getExtractedExistentialValInfo();
+ isType = builder->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, targetType, witness);
+ }
+ else
+ {
+ SLANG_ASSERT(value.val);
+ auto leftType = lowerType(context, expr->value->type);
+ IRInst* args[] = { leftType, targetType };
+ isType = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_TypeEquals, 2, args);
+ }
IRBlock* trueBlock;
IRBlock* falseBlock;
IRBlock* afterBlock;
builder->emitIfElseWithBlocks(isType, trueBlock, falseBlock, afterBlock);
builder->setInsertInto(trueBlock);
- auto irVal = builder->emitReinterpret(targetType, existentialInfo->extractedVal);
+ auto irVal = builder->emitReinterpret(targetType, existentialInfo ? existentialInfo->extractedVal : getSimpleVal(context, value));
auto optionalVal = builder->emitMakeOptionalValue(optType, irVal);
builder->emitStore(var, optionalVal);
builder->emitBranch(afterBlock);
@@ -4751,11 +4761,29 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
{
return LoweredValInfo::simple(getBuilder()->getBoolValue(expr->constantVal->value));
}
- auto value = lowerLValueExpr(context, expr->value);
- auto type = lowerType(context, expr->type);
- auto witness = lowerSimpleVal(context, expr->witnessArg);
- auto existentialInfo = value.getExtractedExistentialValInfo();
- auto irVal = getBuilder()->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, type, witness);
+ // If expr is a witness, then this is a run-time type check from for an existential type.
+ if (expr->witnessArg)
+ {
+ auto value = lowerLValueExpr(context, expr->value);
+ auto type = lowerType(context, expr->typeExpr.type);
+ auto witness = lowerSimpleVal(context, expr->witnessArg);
+ auto existentialInfo = value.getExtractedExistentialValInfo();
+ auto irVal = getBuilder()->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, type, witness);
+ return LoweredValInfo::simple(irVal);
+ }
+ // For all other cases, we map to a simple type equality check in the IR.
+ IRType* leftType = nullptr;
+ if (auto typeType = as<TypeType>(expr->value->type))
+ {
+ leftType = lowerType(context, typeType->getType());
+ }
+ else
+ {
+ leftType = lowerType(context, expr->value->type);
+ }
+ auto rightType = lowerType(context, expr->typeExpr.type);
+ IRInst* args[] = { leftType, rightType };
+ auto irVal = getBuilder()->emitIntrinsicInst(getBuilder()->getBoolType(), kIROp_TypeEquals, 2, args);
return LoweredValInfo::simple(irVal);
}
diff --git a/tests/language-feature/types/as-is-generic.slang b/tests/language-feature/types/as-is-generic.slang
new file mode 100644
index 000000000..6790eda37
--- /dev/null
+++ b/tests/language-feature/types/as-is-generic.slang
@@ -0,0 +1,33 @@
+// optional.slang
+
+// Test that `is` and `as` operator works on generic types.
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -compute -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -cpu -compute -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+T compute<T>(T a1, T a2)
+{
+ if (a1 is float)
+ {
+ return reinterpret<T>((a1 as float).value + (a2 as float).value);
+ }
+ else if (T is int)
+ {
+ return reinterpret<T>((a1 as int).value - (a2 as int).value);
+ }
+ return T();
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ // CHECK: 3.0
+ outputBuffer[0] = compute(1.0f, 2.0f);
+
+ // CHECK: 1.0
+ outputBuffer[1] = compute(2, 1);
+}