summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-08-10 14:11:27 -0700
committerGitHub <noreply@github.com>2022-08-10 14:11:27 -0700
commit88f04c29244af23c1cdd472d8d1ae3e5a650494e (patch)
tree398e55440e8f7ad157d15b2b75d9887236eaa126
parentfcdb4629c4c3dd2931eaa88b96b668d914c4519c (diff)
`is` and `as` operator and `Optional<T>`. (#2355)
* `is` and `as` operator and `Optional<T>`. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/core.meta.slang19
-rw-r--r--source/slang/slang-ast-builder.cpp25
-rw-r--r--source/slang/slang-ast-builder.h9
-rw-r--r--source/slang/slang-ast-expr.h38
-rw-r--r--source/slang/slang-ast-iterator.h18
-rw-r--r--source/slang/slang-ast-type.cpp5
-rw-r--r--source/slang/slang-ast-type.h6
-rw-r--r--source/slang/slang-check-conformance.cpp22
-rw-r--r--source/slang/slang-check-expr.cpp84
-rw-r--r--source/slang/slang-check-impl.h12
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit.cpp4
-rw-r--r--source/slang/slang-ir-any-value-marshalling.cpp7
-rw-r--r--source/slang/slang-ir-inst-defs.h7
-rw-r--r--source/slang/slang-ir-insts.h48
-rw-r--r--source/slang/slang-ir-lower-generics.cpp31
-rw-r--r--source/slang/slang-ir-lower-optional-type.cpp239
-rw-r--r--source/slang/slang-ir-lower-optional-type.h16
-rw-r--r--source/slang/slang-ir-peephole.cpp82
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp2
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp28
-rw-r--r--source/slang/slang-ir.cpp63
-rw-r--r--source/slang/slang-ir.h8
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp18
-rw-r--r--source/slang/slang-language-server-completion.cpp3
-rw-r--r--source/slang/slang-lower-to-ir.cpp62
-rw-r--r--source/slang/slang-parser.cpp41
-rw-r--r--tests/language-feature/interfaces/is-as-dynamic.slang48
-rw-r--r--tests/language-feature/interfaces/is-as-dynamic.slang.expected.txt1
-rw-r--r--tests/language-feature/interfaces/is-as.slang46
-rw-r--r--tests/language-feature/interfaces/is-as.slang.expected.txt1
33 files changed, 950 insertions, 53 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 5f8c0bd2d..7e4144240 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -382,6 +382,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-function.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generics.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-optional-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-reinterpret.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-result-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h" />
@@ -543,6 +544,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-function.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generics.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-optional-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-reinterpret.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-result-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 18bebd332..27589ffee 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -243,6 +243,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generics.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-optional-type.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-reinterpret.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -722,6 +725,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generics.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-optional-type.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-reinterpret.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 41f066486..616511b01 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -456,6 +456,25 @@ __intrinsic_type($(kIROp_RefType))
struct Ref
{};
+__generic<T>
+__magic_type(OptionalType)
+__intrinsic_type($(kIROp_OptionalType))
+struct Optional
+{
+ property bool hasValue
+ {
+ __intrinsic_op($(kIROp_OptionalHasValue))
+ get;
+ }
+
+ property T value
+ {
+ __intrinsic_op($(kIROp_GetOptionalValue))
+ get;
+ }
+};
+
+
__magic_type(StringType)
__intrinsic_type($(kIROp_StringType))
struct String
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 868763f76..2acdc26d3 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -212,6 +212,13 @@ NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
return (NodeBase*)createFunc(this);
}
+Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName)
+{
+ auto declRef = getBuiltinDeclRef(magicTypeName, makeConstArrayViewSingle<Val*>(typeParam));
+ auto rsType = DeclRefType::create(this, declRef);
+ return rsType;
+}
+
PtrType* ASTBuilder::getPtrType(Type* valueType)
{
return dynamicCast<PtrType>(getPtrType(valueType, "PtrType"));
@@ -233,23 +240,15 @@ RefType* ASTBuilder::getRefType(Type* valueType)
return dynamicCast<RefType>(getPtrType(valueType, "RefType"));
}
-PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
+OptionalType* ASTBuilder::getOptionalType(Type* valueType)
{
- auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl(ptrTypeName));
- return getPtrType(valueType, genericDecl);
+ auto rsType = getSpecializedBuiltinType(valueType, "OptionalType");
+ return as<OptionalType>(rsType);
}
-PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, GenericDecl* genericDecl)
+PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
{
- auto typeDecl = genericDecl->inner;
-
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = genericDecl;
- substitutions->args.add(valueType);
-
- auto declRef = DeclRef<Decl>(typeDecl, substitutions);
- auto rsType = DeclRefType::create(this, declRef);
- return as<PtrTypeBase>(rsType);
+ return as<PtrTypeBase>(getSpecializedBuiltinType(valueType, ptrTypeName));
}
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 3c0303e70..e62e92a7b 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -131,6 +131,8 @@ public:
/// Get a builtin type by the BaseType
SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; }
+ Type* getSpecializedBuiltinType(Type* typeParam, const char* magicTypeName);
+
Type* getInitializerListType() { return m_sharedASTBuilder->m_initializerListType; }
Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; }
Type* getErrorType() { return m_sharedASTBuilder->m_errorType; }
@@ -152,14 +154,13 @@ public:
// Construct the type `Ref<valueType>`
RefType* getRefType(Type* valueType);
+ // Construct the type `Optional<valueType>`
+ OptionalType* getOptionalType(Type* valueType);
+
// Construct a pointer type like `Ptr<valueType>`, but where
// the actual type name for the pointer type is given by `ptrTypeName`
PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName);
- // Construct a pointer type like `Ptr<valueType>`, but where
- // the generic declaration for the pointer type is `genericDecl`
- PtrTypeBase* getPtrType(Type* valueType, GenericDecl* genericDecl);
-
ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount);
VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount);
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 4a9cc475d..2dfd937a4 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -299,6 +299,44 @@ class CastToSuperTypeExpr: public Expr
Val* witnessArg = nullptr;
};
+ /// A `value is Type` expression that evaluates to `true` if type of `value` is a sub-type of
+ /// `Type`.
+class IsTypeExpr : public Expr
+{
+ SLANG_AST_CLASS(IsTypeExpr)
+
+ Expr* value = nullptr;
+ TypeExp typeExpr;
+
+ // A witness showing that `typeExpr.type` is a subtype of `typeof(value)`.
+ Val* witnessArg = nullptr;
+
+ bool isAlwaysTrue = false;
+};
+
+ /// A `value as Type` expression that casts `value` to `Type` within type hierarchy.
+ /// The result is undefined if `value` is not `Type`.
+class AsTypeExpr : public Expr
+{
+ SLANG_AST_CLASS(AsTypeExpr)
+
+ Expr* value = nullptr;
+ Expr* typeExpr = nullptr;
+
+ // A witness showing that `typeExpr` is a subtype of `typeof(value)`.
+ Val* witnessArg = nullptr;
+
+};
+
+class MakeOptionalExpr : public Expr
+{
+ SLANG_AST_CLASS(MakeOptionalExpr)
+
+ // If `value` is null, this constructs an `Optional<T>` that doesn't have a value.
+ Expr* value = nullptr;
+ Expr* typeExpr = nullptr;
+};
+
/// A cast of a value to the same type, with different modifiers.
///
/// The type being cast to is stored as this expression's `type`.
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index d439420a9..8461ff7a3 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -235,6 +235,24 @@ struct ASTIterator
iterator->maybeDispatchCallback(expr);
dispatchIfNotNull(expr->base.exp);
}
+ void visitAsTypeExpr(AsTypeExpr* expr)
+ {
+ iterator->maybeDispatchCallback(expr);
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->typeExpr);
+ }
+ void visitIsTypeExpr(IsTypeExpr* expr)
+ {
+ iterator->maybeDispatchCallback(expr);
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->typeExpr.exp);
+ }
+ void visitMakeOptionalExpr(MakeOptionalExpr* expr)
+ {
+ iterator->maybeDispatchCallback(expr);
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->typeExpr);
+ }
};
struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor>
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 43fe751ee..664c940a8 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -517,6 +517,11 @@ Type* PtrTypeBase::getValueType()
return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]);
}
+Type* OptionalType::getValueType()
+{
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]);
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void NamedExpressionType::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 5d4e42bfb..7b94cbe6d 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -572,6 +572,12 @@ class RefType : public ParamDirectionType
SLANG_AST_CLASS(RefType)
};
+class OptionalType : public BuiltinType
+{
+ SLANG_AST_CLASS(OptionalType)
+ Type* getValueType();
+};
+
// A type alias of some kind (e.g., via `typedef`)
class NamedExpressionType : public Type
{
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index e77ea4981..5889c6140 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -390,6 +390,28 @@ namespace Slang
return _isDeclaredSubtype(subType, subType, superTypeDeclRef, nullptr, nullptr);
}
+ bool SemanticsVisitor::isDeclaredSubtype(
+ Type* subType,
+ Type* superType)
+ {
+ if (auto declRefType = as<DeclRefType>(superType))
+ {
+ if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ return _isDeclaredSubtype(subType, subType, aggTypeDeclRef, nullptr, nullptr);
+ }
+ return false;
+ }
+
+ bool SemanticsVisitor::isInterfaceType(Type* type)
+ {
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>())
+ return true;
+ }
+ return false;
+ }
+
Val* SemanticsVisitor::tryGetSubtypeWitness(
Type* subType,
DeclRef<AggTypeDecl> superTypeDeclRef)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a787af211..714eba9a3 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1753,6 +1753,90 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr)
+ {
+ expr->typeExpr = CheckProperType(expr->typeExpr);
+ auto originalVal = CheckTerm(expr->value);
+ expr->type = m_astBuilder->getBoolType();
+ expr->value = originalVal;
+
+ // If value is a subtype of `type`, then this expr is always true.
+ if (isDeclaredSubtype(expr->value->type.type, expr->typeExpr.type))
+ {
+ // Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario,
+ // so that the language server can still see the original syntax tree.
+ expr->isAlwaysTrue = true;
+ return expr;
+ }
+
+ // Otherwise, we need to ensure the target type is a subtype of value->type.
+ // For now we can only support the scenario where `expr->value` is an interface type.
+ if (!isInterfaceType(originalVal->type))
+ {
+ getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType);
+ }
+
+ expr->value = maybeOpenExistential(originalVal);
+ expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, originalVal->type.type);
+ if (expr->witnessArg)
+ {
+ return expr;
+ }
+
+ if (!as<ErrorType>(expr->typeExpr.type) && !as<ErrorType>(expr->value->type.type))
+ {
+ getSink()->diagnose(expr, Diagnostics::typeNotInTheSameHierarchy, expr->value->type.type, expr->typeExpr.type);
+ }
+
+ expr->type = m_astBuilder->getErrorType();
+ return expr;
+ }
+
+ Expr* SemanticsExprVisitor::visitAsTypeExpr(AsTypeExpr* expr)
+ {
+ TypeExp typeExpr;
+ typeExpr.exp = expr->typeExpr;
+ typeExpr = CheckProperType(typeExpr);
+ expr->value = CheckTerm(expr->value);
+ auto optType = m_astBuilder->getOptionalType(typeExpr.type);
+ expr->type = optType;
+
+ // If value is a subtype of `type`, then this expr is equivalent to a CastToSuperTypeExpr.
+ if (auto witness = tryGetSubtypeWitness(expr->value->type.type, typeExpr.type))
+ {
+ auto castToSuperType = createCastToSuperTypeExpr(typeExpr.type, expr->value, witness);
+ auto makeOptional = m_astBuilder->create<MakeOptionalExpr>();
+ makeOptional->loc = expr->loc;
+ makeOptional->type = optType;
+ makeOptional->value = castToSuperType;
+ makeOptional->typeExpr = typeExpr.exp;
+ return makeOptional;
+ }
+
+ // For now we can only support the scenario where `expr->value` is an interface type.
+ if (!isInterfaceType(expr->value->type))
+ {
+ getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType);
+ }
+
+ expr->typeExpr = typeExpr.exp;
+ expr->witnessArg = tryGetSubtypeWitness(typeExpr.type, expr->value->type.type);
+ if (expr->witnessArg)
+ {
+ expr->value = maybeOpenExistential(expr->value);
+ return expr;
+ }
+
+ if (!as<ErrorType>(typeExpr.type) && !as<ErrorType>(expr->value->type.type))
+ {
+ getSink()->diagnose(expr, Diagnostics::typeNotInTheSameHierarchy, expr->value->type.type, typeExpr.type);
+ }
+
+ expr->type = m_astBuilder->getErrorType();
+
+ return expr;
+ }
+
Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
{
Expr* expr = inExpr;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index f2f7a6bd1..df713d80f 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1278,6 +1278,13 @@ namespace Slang
Type* subType,
DeclRef<AggTypeDecl> superTypeDeclRef);
+ /// Check whether `subType` is a sub-type of `supType`.
+ bool isDeclaredSubtype(
+ Type* subType,
+ Type* supType);
+
+ bool isInterfaceType(Type* type);
+
/// Check whether `subType` is a sub-type of `superTypeDeclRef`,
/// and return a witness to the sub-type relationship if it holds
/// (return null otherwise).
@@ -1714,6 +1721,10 @@ namespace Slang
Expr* visitTryExpr(TryExpr* expr);
+ Expr* visitIsTypeExpr(IsTypeExpr* expr);
+
+ Expr* visitAsTypeExpr(AsTypeExpr* expr);
+
//
// Some syntax nodes should not occur in the concrete input syntax,
// and will only appear *after* checking is complete. We need to
@@ -1740,6 +1751,7 @@ namespace Slang
CASE(LetExpr)
CASE(ExtractExistentialValueExpr)
CASE(OpenRefExpr)
+ CASE(MakeOptionalExpr)
#undef CASE
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 2b12c3de4..54d81da7d 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -251,6 +251,7 @@ DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for t
DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.")
+DIAGNOSTIC(30018, Error, typeNotInTheSameHierarchy, "as/is operator requires '$0' and '$1' to be in the same type hierarchy.")
DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'")
DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?")
DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)")
@@ -286,6 +287,7 @@ DIAGNOSTIC(30200, Error, redeclaration, "declaration of '$0' conflicts with exis
DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body")
DIAGNOSTIC(30202, Error, functionRedeclarationWithDifferentReturnType, "function '$0' declared to return '$1' was previously declared to return '$2'")
+DIAGNOSTIC(30300, Error, isOperatorValueMustBeInterfaceType, "'is'/'as' operator requires an interface-typed expression.")
DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'")
DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 21073f4c8..da8739a49 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -25,6 +25,7 @@
#include "slang-ir-lower-generics.h"
#include "slang-ir-lower-tuple-types.h"
#include "slang-ir-lower-result-type.h"
+#include "slang-ir-lower-optional-type.h"
#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-lower-reinterpret.h"
#include "slang-ir-metadata.h"
@@ -387,6 +388,9 @@ Result linkAndOptimizeIR(
// will run a DCE pass to clean up after the specialization.
//
simplifyIR(irModule);
+
+ lowerOptionalType(irModule, sink);
+
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER DCE");
#endif
diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp
index 5c91766f7..39292e2b1 100644
--- a/source/slang/slang-ir-any-value-marshalling.cpp
+++ b/source/slang/slang-ir-any-value-marshalling.cpp
@@ -762,6 +762,13 @@ namespace Slang
size += kRTTIHeaderSize;
return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
}
+ case kIROp_ExtractExistentialType:
+ {
+ auto existentialValue = type->getOperand(0);
+ auto interfaceType = cast<IRInterfaceType>(existentialValue->getDataType());
+ auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
+ return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
+ }
default:
if (as<IRTextureTypeBase>(type) || as<IRSamplerStateTypeBase>(type))
{
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index aeb6d4ea1..978317ccd 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -57,6 +57,7 @@ INST(Nop, nop, 0, 0)
INST(ConjunctionType, Conjunction, 0, 0)
INST(AttributedType, Attributed, 0, 0)
INST(ResultType, Result, 2, 0)
+ INST(OptionalType, Optional, 1, 0)
INST(DifferentialPairType, DiffPair, 1, 0)
@@ -292,7 +293,10 @@ INST(MakeResultError, makeResultError, 1, 0)
INST(IsResultError, isResultError, 1, 0)
INST(GetResultError, getResultError, 1, 0)
INST(GetResultValue, getResultValue, 1, 0)
-
+INST(GetOptionalValue, getOptionalValue, 1, 0)
+INST(OptionalHasValue, optionalHasValue, 1, 0)
+INST(MakeOptionalValue, makeOptionalValue, 1, 0)
+INST(MakeOptionalNone, makeOptionalNone, 1, 0)
INST(Call, call, 1, 0)
INST(RTTIObject, rtti_object, 0, 0)
@@ -745,6 +749,7 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
INST(CastPtrToBool, CastPtrToBool, 1, 0)
+INST(IsType, IsType, 3, 0)
INST(JVPDifferentiate, jvpDifferentiate, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2e2dbed5a..7f4e991b0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1902,6 +1902,33 @@ struct IRGetResultError : IRInst
IRInst* getResultOperand() { return getOperand(0); }
};
+struct IROptionalHasValue : IRInst
+{
+ IR_LEAF_ISA(OptionalHasValue)
+
+ IRInst* getOptionalOperand() { return getOperand(0); }
+};
+
+struct IRGetOptionalValue : IRInst
+{
+ IR_LEAF_ISA(GetOptionalValue)
+
+ IRInst* getOptionalOperand() { return getOperand(0); }
+};
+
+struct IRMakeOptionalValue : IRInst
+{
+ IR_LEAF_ISA(MakeOptionalValue)
+
+ IRInst* getValue() { return getOperand(0); }
+};
+
+struct IRMakeOptionalNone : IRInst
+{
+ IR_LEAF_ISA(MakeOptionalNone)
+ IRInst* getDefaultValue() { return getOperand(0); }
+};
+
/// An instruction that packs a concrete value into an existential-type "box"
struct IRMakeExistential : IRInst
{
@@ -1985,6 +2012,17 @@ struct IRLiveRangeStart : IRLiveRangeMarker
IR_LEAF_ISA(LiveRangeStart);
};
+struct IRIsType : IRInst
+{
+ IR_LEAF_ISA(IsType);
+
+ IRInst* getValue() { return getOperand(0); }
+ IRInst* getValueWitness() { return getOperand(1); }
+
+ IRInst* getTypeOperand() { return getOperand(2); }
+ IRInst* getTargetWitness() { return getOperand(3); }
+};
+
/// Demarks where the referenced item is no longer live, optimimally (although not
/// necessarily) at the previous instruction.
///
@@ -2256,6 +2294,7 @@ public:
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3);
IRResultType* getResultType(IRType* valueType, IRType* errorType);
+ IROptionalType* getOptionalType(IRType* valueType);
IRBasicBlockType* getBasicBlockType();
IRWitnessTableType* getWitnessTableType(IRType* baseType);
@@ -2504,7 +2543,10 @@ public:
IRInst* emitIsResultError(IRInst* result);
IRInst* emitGetResultError(IRInst* result);
IRInst* emitGetResultValue(IRInst* result);
-
+ IRInst* emitOptionalHasValue(IRInst* optValue);
+ IRInst* emitGetOptionalValue(IRInst* optValue);
+ IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value);
+ IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
@@ -2581,6 +2623,8 @@ public:
IRUndefined* emitUndefined(IRType* type);
+ IRInst* emitReinterpret(IRInst* type, IRInst* value);
+
IRInst* findOrAddInst(
IRType* type,
IROp op,
@@ -2717,6 +2761,8 @@ public:
IRInst* coord,
IRInst* value);
+ IRInst* emitIsType(IRInst* value, IRInst* witness, IRInst* typeOperand, IRInst* targetWitness);
+
IRInst* emitFieldExtract(
IRType* type,
IRInst* base,
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 3b9a17738..c95f6976c 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -9,11 +9,12 @@
#include "slang-ir-lower-generic-function.h"
#include "slang-ir-lower-generic-call.h"
#include "slang-ir-lower-generic-type.h"
+#include "slang-ir-inst-pass-base.h"
#include "slang-ir-specialize-dispatch.h"
#include "slang-ir-specialize-dynamic-associatedtype-lookup.h"
#include "slang-ir-witness-table-wrapper.h"
-#include "slang-ir-ssa.h"
-#include "slang-ir-dce.h"
+#include "slang-ir-ssa-simplification.h"
+
namespace Slang
{
@@ -93,6 +94,23 @@ namespace Slang
inst->removeAndDeallocate();
}
}
+
+ void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext)
+ {
+ InstPassBase pass(sharedContext->module);
+ pass.processInstsOfType<IRIsType>(kIROp_IsType, [&](IRIsType* inst)
+ {
+ auto witnessTableType = as<IRWitnessTableTypeBase>(inst->getValueWitness()->getDataType());
+ if (witnessTableType && isComInterfaceType((IRType*)witnessTableType->getConformanceType()))
+ return;
+ IRBuilder builder(sharedContext->sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto eqlInst = builder.emitEql(builder.emitGetSequentialIDInst(inst->getValueWitness()),
+ builder.emitGetSequentialIDInst(inst->getTargetWitness()));
+ inst->replaceUsesWith(eqlInst);
+ inst->removeAndDeallocate();
+ });
+ }
// Turn all references of witness table or RTTI objects into integer IDs, generate
// specialized `switch` based dispatch functions based on witness table IDs, and remove
@@ -105,6 +123,8 @@ namespace Slang
if (sink->getErrorCount() != 0)
return;
+ lowerIsTypeInsts(sharedContext);
+
specializeDynamicAssociatedTypeLookup(sharedContext);
if (sink->getErrorCount() != 0)
return;
@@ -112,6 +132,7 @@ namespace Slang
sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
sharedContext->mapInterfaceRequirementKeyValue.Clear();
+
specializeRTTIObjectReferences(sharedContext);
cleanUpRTTIHandleTypes(sharedContext);
@@ -175,6 +196,7 @@ namespace Slang
// and used to create a tuple representing the existential value.
augmentMakeExistentialInsts(module);
+
lowerGenericFunctions(&sharedContext);
if (sink->getErrorCount() != 0)
return;
@@ -200,6 +222,8 @@ namespace Slang
// real RTTI objects and witness tables.
specializeRTTIObjects(&sharedContext, sink);
+ simplifyIR(module);
+
lowerTuples(module, sink);
if (sink->getErrorCount() != 0)
return;
@@ -210,7 +234,6 @@ namespace Slang
// We might have generated new temporary variables during lowering.
// An SSA pass can clean up unnecessary load/stores.
- constructSSA(module);
- eliminateDeadCode(module);
+ simplifyIR(module);
}
} // namespace Slang
diff --git a/source/slang/slang-ir-lower-optional-type.cpp b/source/slang/slang-ir-lower-optional-type.cpp
new file mode 100644
index 000000000..79be4e042
--- /dev/null
+++ b/source/slang/slang-ir-lower-optional-type.cpp
@@ -0,0 +1,239 @@
+// slang-ir-lower-optional-type.cpp
+
+#include "slang-ir-lower-optional-type.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+ struct OptionalTypeLoweringContext
+ {
+ IRModule* module;
+ DiagnosticSink* sink;
+
+ SharedIRBuilder sharedBuilderStorage;
+
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+
+ struct LoweredOptionalTypeInfo : public RefObject
+ {
+ IRType* optionalType = nullptr;
+ IRType* valueType = nullptr;
+ IRType* loweredType = nullptr;
+ IRStructField* valueField = nullptr;
+ IRStructField* hasValueField = nullptr;
+ };
+ Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo;
+ Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes;
+
+ IRType* maybeLowerOptionalType(IRBuilder* builder, IRType* type)
+ {
+ if (auto info = getLoweredOptionalType(builder, type))
+ return info->loweredType;
+ else
+ return type;
+ }
+
+ LoweredOptionalTypeInfo* getLoweredOptionalType(IRBuilder* builder, IRInst* type)
+ {
+ if (auto loweredInfo = loweredOptionalTypes.TryGetValue(type))
+ return loweredInfo->Ptr();
+ if (auto loweredInfo = mapLoweredTypeToOptionalTypeInfo.TryGetValue(type))
+ return loweredInfo->Ptr();
+
+ if (!type)
+ return nullptr;
+ if (type->getOp() != kIROp_OptionalType)
+ return nullptr;
+
+ RefPtr<LoweredOptionalTypeInfo> info = new LoweredOptionalTypeInfo();
+ info->optionalType = (IRType*)type;
+ auto optionalType = cast<IROptionalType>(type);
+ auto valueType = optionalType->getValueType();
+ info->valueType = valueType;
+
+ auto structType = builder->createStructType();
+ info->loweredType = structType;
+ builder->addNameHintDecoration(structType, UnownedStringSlice("OptionalType"));
+
+ info->valueType = valueType;
+ auto valueKey = builder->createStructKey();
+ builder->addNameHintDecoration(valueKey, UnownedStringSlice("value"));
+ info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType);
+
+ auto boolType = builder->getBoolType();
+ auto hasValueKey = builder->createStructKey();
+ builder->addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
+ info->hasValueField = builder->createStructField(structType, hasValueKey, (IRType*)boolType);
+
+ mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info;
+ loweredOptionalTypes[type] = info;
+ return info.Ptr();
+ }
+
+ void addToWorkList(
+ IRInst* inst)
+ {
+ for (auto ii = inst->getParent(); ii; ii = ii->getParent())
+ {
+ if (as<IRGeneric>(ii))
+ return;
+ }
+
+ if (workListSet.Contains(inst))
+ return;
+
+ workList.add(inst);
+ workListSet.Add(inst);
+ }
+
+ void processMakeOptionalValue(IRMakeOptionalValue* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto info = getLoweredOptionalType(builder, inst->getDataType());
+ List<IRInst*> operands;
+ operands.add(inst->getOperand(0));
+ operands.add(builder->getBoolValue(true));
+ auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+
+ void processMakeOptionalNone(IRMakeOptionalNone* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto info = getLoweredOptionalType(builder, inst->getDataType());
+
+ List<IRInst*> operands;
+ operands.add(inst->getDefaultValue());
+ operands.add(builder->getBoolValue(false));
+ auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+
+ IRInst* getOptionalHasValue(IRBuilder* builder, IRInst* optionalInst)
+ {
+ auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, optionalInst->getDataType());
+ SLANG_ASSERT(loweredOptionalTypeInfo);
+
+ auto value = builder->emitFieldExtract(
+ builder->getBoolType(),
+ optionalInst,
+ loweredOptionalTypeInfo->hasValueField->getKey());
+ return value;
+ }
+
+ void processGetOptionalHasValue(IROptionalHasValue* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto optionalValue = inst->getOptionalOperand();
+ auto hasVal = getOptionalHasValue(builder, optionalValue);
+ inst->replaceUsesWith(hasVal);
+ inst->removeAndDeallocate();
+ }
+
+ void processGetOptionalValue(IRGetOptionalValue* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto base = inst->getOptionalOperand();
+ auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, base->getDataType());
+ SLANG_ASSERT(loweredOptionalTypeInfo);
+ SLANG_ASSERT(loweredOptionalTypeInfo->valueField);
+ auto getElement = builder->emitFieldExtract(
+ loweredOptionalTypeInfo->valueType,
+ base,
+ loweredOptionalTypeInfo->valueField->getKey());
+ inst->replaceUsesWith(getElement);
+ inst->removeAndDeallocate();
+ }
+
+ void processOptionalType(IROptionalType* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, inst);
+ SLANG_ASSERT(loweredOptionalTypeInfo);
+ SLANG_UNUSED(loweredOptionalTypeInfo);
+ }
+
+ void processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_MakeOptionalValue:
+ processMakeOptionalValue((IRMakeOptionalValue*)inst);
+ break;
+ case kIROp_MakeOptionalNone:
+ processMakeOptionalNone((IRMakeOptionalNone*)inst);
+ break;
+ case kIROp_OptionalHasValue:
+ processGetOptionalHasValue((IROptionalHasValue*)inst);
+ break;
+ case kIROp_GetOptionalValue:
+ processGetOptionalValue((IRGetOptionalValue*)inst);
+ break;
+ case kIROp_OptionalType:
+ processOptionalType((IROptionalType*)inst);
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processModule()
+ {
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->init(module);
+
+ // Deduplicate equivalent types.
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
+ addToWorkList(module->getModuleInst());
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+ workListSet.Remove(inst);
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+
+ // Replace all optional types with lowered struct types.
+ for (auto kv : loweredOptionalTypes)
+ {
+ kv.Key->replaceUsesWith(kv.Value->loweredType);
+ }
+ }
+ };
+
+ void lowerOptionalType(IRModule* module, DiagnosticSink* sink)
+ {
+ OptionalTypeLoweringContext context;
+ context.module = module;
+ context.sink = sink;
+ context.processModule();
+ }
+}
diff --git a/source/slang/slang-ir-lower-optional-type.h b/source/slang/slang-ir-lower-optional-type.h
new file mode 100644
index 000000000..1a011da26
--- /dev/null
+++ b/source/slang/slang-ir-lower-optional-type.h
@@ -0,0 +1,16 @@
+// slang-ir-lower-optional-type.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+ struct IRModule;
+ class DiagnosticSink;
+
+ /// Lower `IROptionalType<T,E>` types to ordinary `struct`s.
+ void lowerOptionalType(
+ IRModule* module,
+ DiagnosticSink* sink);
+
+}
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index ffdb84c4a..d67c87ca9 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -96,6 +96,87 @@ struct PeepholeContext : InstPassBase
changed = true;
}
break;
+ case kIROp_IsType:
+ {
+ auto isTypeInst = as<IRIsType>(inst);
+ auto actualType = isTypeInst->getValue()->getDataType();
+ if (isTypeEqual(actualType, (IRType*)isTypeInst->getTypeOperand()))
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto trueVal = builder.getBoolValue(true);
+ inst->replaceUsesWith(trueVal);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
+ case kIROp_Reinterpret:
+ {
+ if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType()))
+ {
+ inst->replaceUsesWith(inst->getOperand(0));
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
+ case kIROp_UnpackAnyValue:
+ {
+ if (inst->getOperand(0)->getOp() == kIROp_PackAnyValue)
+ {
+ if (isTypeEqual(inst->getOperand(0)->getOperand(0)->getDataType(), inst->getDataType()))
+ {
+ inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ }
+ break;
+ case kIROp_PackAnyValue:
+ {
+ // Pack(obj: anyValueN) : anyValueN --> obj
+ if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType()))
+ {
+ inst->replaceUsesWith(inst->getOperand(0));
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
+ case kIROp_GetOptionalValue:
+ {
+ if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
+ {
+ inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
+ case kIROp_OptionalHasValue:
+ {
+ if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto trueVal = builder.getBoolValue(true);
+ inst->replaceUsesWith(trueVal);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ auto falseVal = builder.getBoolValue(false);
+ inst->replaceUsesWith(falseVal);
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ break;
default:
break;
}
@@ -105,6 +186,7 @@ struct PeepholeContext : InstPassBase
{
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
changed = false;
processAllInsts([this](IRInst* inst) { processInst(inst); });
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index 450867abb..0ca0933f1 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -253,8 +253,6 @@ void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc*
}
if (as<IRWitnessTable>(args[0]->getDataType()))
continue;
- auto seqIdArg = builder.emitGetSequentialIDInst(args[0]);
- args[0] = seqIdArg;
auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args);
call->replaceUsesWith(newCall);
call->removeAndDeallocate();
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index 52c8edca6..dbcc7a54c 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -168,25 +168,15 @@ struct AssociatedTypeLookupSpecializationContext
void processGetSequentialIDInst(IRGetSequentialID* inst)
{
- if (inst->getRTTIOperand()->getDataType()->getOp() == kIROp_WitnessTableIDType)
- {
- // If the operand is a witness table id, just return the operand.
- inst->replaceUsesWith(inst->getRTTIOperand());
- inst->removeAndDeallocate();
- }
- else if (inst->getRTTIOperand()->getDataType()->getOp() == kIROp_VectorType)
- {
- // If the operand is a witness table, it is already replaced with a uint2
- // at this point, where the first element in the uint2 is the id of the
- // witness table.
- auto vectorType = inst->getRTTIOperand()->getDataType();
- IRBuilder builder(sharedContext->sharedBuilderStorage);
- builder.setInsertBefore(inst);
- UInt index = 0;
- auto id = builder.emitSwizzle(as<IRVectorType>(vectorType)->getElementType(), inst->getRTTIOperand(), 1, &index);
- inst->replaceUsesWith(id);
- inst->removeAndDeallocate();
- }
+ // If the operand is a witness table, it is already replaced with a uint2
+ // at this point, where the first element in the uint2 is the id of the
+ // witness table.
+ IRBuilder builder(sharedContext->sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ UInt index = 0;
+ auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index);
+ inst->replaceUsesWith(id);
+ inst->removeAndDeallocate();
}
void processModule()
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index c66f0d555..fd7cbe408 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2587,7 +2587,7 @@ namespace Slang
IRAnyValueType* IRBuilder::getAnyValueType(IRIntegerValue size)
{
return (IRAnyValueType*)getType(kIROp_AnyValueType,
- getIntValue(getIntType(), size));
+ getIntValue(getUIntType(), size));
}
IRAnyValueType* IRBuilder::getAnyValueType(IRInst* size)
@@ -2624,6 +2624,11 @@ namespace Slang
return (IRResultType*)getType(kIROp_ResultType, 2, operands);
}
+ IROptionalType* IRBuilder::getOptionalType(IRType* valueType)
+ {
+ return (IROptionalType*)getType(kIROp_OptionalType, valueType);
+ }
+
IRBasicBlockType* IRBuilder::getBasicBlockType()
{
return (IRBasicBlockType*)getType(kIROp_BasicBlockType);
@@ -2971,6 +2976,11 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitReinterpret(IRInst* type, IRInst* value)
+ {
+ return emitIntrinsicInst((IRType*)type, kIROp_Reinterpret, 1, &value);
+ }
+
IRLiveRangeStart* IRBuilder::emitLiveRangeStart(IRInst* referenced)
{
// This instruction doesn't produce any result,
@@ -3323,6 +3333,42 @@ namespace Slang
&result);
}
+ IRInst* IRBuilder::emitOptionalHasValue(IRInst* optValue)
+ {
+ return emitIntrinsicInst(
+ getBoolType(),
+ kIROp_OptionalHasValue,
+ 1,
+ &optValue);
+ }
+
+ IRInst* IRBuilder::emitGetOptionalValue(IRInst* optValue)
+ {
+ return emitIntrinsicInst(
+ cast<IROptionalType>(optValue->getDataType())->getValueType(),
+ kIROp_GetOptionalValue,
+ 1,
+ &optValue);
+ }
+
+ IRInst* IRBuilder::emitMakeOptionalValue(IRInst* optType, IRInst* value)
+ {
+ return emitIntrinsicInst(
+ (IRType*)optType,
+ kIROp_MakeOptionalValue,
+ 1,
+ &value);
+ }
+
+ IRInst* IRBuilder::emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue)
+ {
+ return emitIntrinsicInst(
+ (IRType*)optType,
+ kIROp_MakeOptionalNone,
+ 1,
+ &defaultValue);
+ }
+
IRInst* IRBuilder::emitMakeVector(
IRType* type,
UInt argCount,
@@ -3819,6 +3865,14 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitIsType(IRInst* value, IRInst* witness, IRInst* typeOperand, IRInst* targetWitness)
+ {
+ IRInst* args[] = { value, witness, typeOperand, targetWitness };
+ auto inst = createInst<IRIsType>(this, kIROp_IsType, getBoolType(), SLANG_COUNT_OF(args), args);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitFieldExtract(
IRType* type,
IRInst* base,
@@ -6079,6 +6133,10 @@ namespace Slang
case kIROp_GetResultError:
case kIROp_GetResultValue:
case kIROp_IsResultError:
+ case kIROp_MakeOptionalValue:
+ case kIROp_MakeOptionalNone:
+ case kIROp_OptionalHasValue:
+ case kIROp_GetOptionalValue:
case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads
case kIROp_ImageSubscript:
case kIROp_FieldExtract:
@@ -6118,6 +6176,9 @@ namespace Slang
case kIROp_WrapExistential:
case kIROp_BitCast:
case kIROp_AllocObj:
+ case kIROp_PackAnyValue:
+ case kIROp_UnpackAnyValue:
+ case kIROp_Reinterpret:
return false;
}
}
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 47a724def..47c6621f3 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1488,6 +1488,14 @@ struct IRResultType : IRType
IRType* getErrorType() { return (IRType*)getOperand(1); }
};
+/// Represents an `Optional<T>`.
+struct IROptionalType : IRType
+{
+ IR_LEAF_ISA(OptionalType)
+
+ IRType* getValueType() { return (IRType*)getOperand(0); }
+};
+
struct IRTypeType : IRType
{
IR_LEAF_ISA(TypeType);
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 3de335c33..353c98fa4 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -383,6 +383,24 @@ public:
}
return dispatchIfNotNull(expr->base.exp);
}
+ bool visitAsTypeExpr(AsTypeExpr* expr)
+ {
+ if (dispatchIfNotNull(expr->value))
+ return true;
+ return dispatchIfNotNull(expr->typeExpr);
+ }
+ bool visitIsTypeExpr(IsTypeExpr* expr)
+ {
+ if (dispatchIfNotNull(expr->value))
+ return true;
+ return dispatchIfNotNull(expr->typeExpr.exp);
+ }
+ bool visitMakeOptionalExpr(MakeOptionalExpr* expr)
+ {
+ if (dispatchIfNotNull(expr->typeExpr))
+ return true;
+ return dispatchIfNotNull(expr->value);
+ }
bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); }
bool visitTryExpr(TryExpr* expr) { return dispatchIfNotNull(expr->base); }
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index 21586089a..4ae5bcb37 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -30,7 +30,8 @@ static const char* kStmtKeywords[] = {
"protected", "typedef", "typealias", "uniform", "export", "groupshared",
"extension", "associatedtype", "this", "namespace", "This", "using",
"__generic", "__exported", "import", "enum", "break", "continue",
- "discard", "defer", "cbuffer", "tbuffer", "func"};
+ "discard", "defer", "cbuffer", "tbuffer", "func", "is",
+ "as", "nullptr", "true", "false"};
static const char* hlslSemanticNames[] = {
"register",
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index be7373d40..aef60c9d9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3384,6 +3384,24 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return LoweredValInfo::simple(irLit);
}
+ LoweredValInfo visitMakeOptionalExpr(MakeOptionalExpr* expr)
+ {
+ if (expr->value)
+ {
+ auto val = lowerRValueExpr(context, expr->value);
+ auto optType = lowerType(context, expr->type);
+ auto irVal = context->irBuilder->emitMakeOptionalValue(optType, val.val);
+ return LoweredValInfo::simple(irVal);
+ }
+ else
+ {
+ auto optType = lowerType(context, expr->type);
+ auto defaultVal = getDefaultVal(as<OptionalType>(expr->type)->getValueType());
+ auto irVal = context->irBuilder->emitMakeOptionalNone(optType, defaultVal.val);
+ return LoweredValInfo::simple(irVal);
+ }
+ }
+
LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* /*expr*/)
{
SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression");
@@ -3911,6 +3929,50 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ LoweredValInfo visitAsTypeExpr(AsTypeExpr* expr)
+ {
+ auto value = lowerLValueExpr(context, expr->value);
+ auto existentialInfo = value.getExtractedExistentialValInfo();
+ auto optType = lowerType(context, expr->type);
+ SLANG_RELEASE_ASSERT(optType->getOp() == kIROp_OptionalType);
+ auto targetType = optType->getOperand(0);
+ auto witness = lowerSimpleVal(context, expr->witnessArg);
+ auto builder = getBuilder();
+ auto var = builder->emitVar(optType);
+ auto isType = builder->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, targetType, witness);
+ IRBlock* trueBlock;
+ IRBlock* falseBlock;
+ IRBlock* afterBlock;
+ builder->emitIfElseWithBlocks(isType, trueBlock, falseBlock, afterBlock);
+ builder->setInsertInto(trueBlock);
+ auto irVal = builder->emitReinterpret(targetType, existentialInfo->extractedVal);
+ auto optionalVal = builder->emitMakeOptionalValue(optType, irVal);
+ builder->emitStore(var, optionalVal);
+ builder->emitBranch(afterBlock);
+ builder->setInsertInto(falseBlock);
+ auto defaultVal = getDefaultVal(as<OptionalType>(expr->type)->getValueType());
+ auto noneVal = builder->emitMakeOptionalNone(optType, defaultVal.val);
+ builder->emitStore(var, noneVal);
+ builder->emitBranch(afterBlock);
+ builder->setInsertInto(afterBlock);
+ auto result = builder->emitLoad(var);
+ return LoweredValInfo::simple(result);
+ }
+
+ LoweredValInfo visitIsTypeExpr(IsTypeExpr* expr)
+ {
+ if (expr->isAlwaysTrue)
+ {
+ return LoweredValInfo::simple(getBuilder()->getBoolValue(true));
+ }
+ auto value = lowerLValueExpr(context, expr->value);
+ auto type = lowerType(context, expr->type);
+ auto witness = lowerSimpleVal(context, expr->witnessArg);
+ auto existentialInfo = value.getExtractedExistentialValInfo();
+ auto irVal = getBuilder()->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, type, witness);
+ return LoweredValInfo::simple(irVal);
+ }
+
LoweredValInfo visitModifierCastExpr(
ModifierCastExpr* expr)
{
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 17794a4b5..4baa7211e 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -4735,9 +4735,9 @@ namespace Slang
- Precedence GetOpLevel(Parser* parser, TokenType type)
+ Precedence GetOpLevel(Parser* parser, const Token& token)
{
- switch(type)
+ switch(token.type)
{
case TokenType::QuestionMark:
return Precedence::TernaryConditional;
@@ -4790,6 +4790,10 @@ namespace Slang
case TokenType::OpMod:
return Precedence::Multiplicative;
default:
+ if (token.getContent() == "is" || token.getContent() == "as")
+ {
+ return Precedence::RelationalComparison;
+ }
return Precedence::Invalid;
}
}
@@ -4840,16 +4844,39 @@ namespace Slang
auto expr = inExpr;
for(;;)
{
- auto opTokenType = parser->tokenReader.peekTokenType();
- auto opPrec = GetOpLevel(parser, opTokenType);
+ auto opToken = parser->tokenReader.peekToken();
+ auto opPrec = GetOpLevel(parser, opToken);
if(opPrec < prec)
break;
+ // Special case the "is" and "as" operators.
+ if (opToken.type == TokenType::Identifier)
+ {
+ if (opToken.getContent() == "is")
+ {
+ auto isExpr = parser->astBuilder->create<IsTypeExpr>();
+ isExpr->value = expr;
+ parser->ReadToken();
+ isExpr->typeExpr = parser->ParseTypeExp();
+ expr = isExpr;
+ continue;
+ }
+ else if (opToken.getContent() == "as")
+ {
+ auto asExpr = parser->astBuilder->create<AsTypeExpr>();
+ asExpr->value = expr;
+ parser->ReadToken();
+ asExpr->typeExpr = parser->ParseType();
+ expr = asExpr;
+ continue;
+ }
+ }
+
auto op = parseOperator(parser);
// Special case the `?:` operator since it is the
// one non-binary case we need to deal with.
- if(opTokenType == TokenType::QuestionMark)
+ if(opToken.type == TokenType::QuestionMark)
{
SelectExpr* select = parser->astBuilder->create<SelectExpr>();
select->loc = op->loc;
@@ -4869,7 +4896,7 @@ namespace Slang
for(;;)
{
- auto nextOpPrec = GetOpLevel(parser, parser->tokenReader.peekTokenType());
+ auto nextOpPrec = GetOpLevel(parser, parser->tokenReader.peekToken());
if((GetAssociativityFromLevel(nextOpPrec) == Associativity::Right) ? (nextOpPrec < opPrec) : (nextOpPrec <= opPrec))
break;
@@ -4877,7 +4904,7 @@ namespace Slang
right = parseInfixExprWithPrecedence(parser, right, nextOpPrec);
}
- if (opTokenType == TokenType::OpAssign)
+ if (opToken.type == TokenType::OpAssign)
{
AssignExpr* assignExpr = parser->astBuilder->create<AssignExpr>();
assignExpr->loc = op->loc;
diff --git a/tests/language-feature/interfaces/is-as-dynamic.slang b/tests/language-feature/interfaces/is-as-dynamic.slang
new file mode 100644
index 000000000..4499db53a
--- /dev/null
+++ b/tests/language-feature/interfaces/is-as-dynamic.slang
@@ -0,0 +1,48 @@
+// is-as-dynamic.slang
+
+// Test that `is` and `as` operators works as intended in dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE: -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[anyValueSize(8)]
+interface IFoo
+{
+ int method();
+}
+
+//TEST_INPUT: type_conformance Impl1:IFoo = 0
+struct Impl1 : IFoo
+{
+ int data;
+ int method() { return data; }
+}
+
+//TEST_INPUT: type_conformance Impl2:IFoo = 1
+struct Impl2 : IFoo
+{
+ int data1;
+ int data2;
+ int method() { return data1 + data2; }
+}
+
+int getData(IFoo foo)
+{
+ let castResult = foo as Impl2;
+ if (castResult.hasValue && foo is Impl2 && !(foo is Impl1))
+ {
+ return castResult.value.method();
+ }
+ return 0;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int2 data = int2(1, 2);
+ IFoo dynamicObject = createDynamicObject<IFoo, int2>(1, data);
+ int outVal = getData(dynamicObject);
+ outputBuffer[0] = outVal;
+}
diff --git a/tests/language-feature/interfaces/is-as-dynamic.slang.expected.txt b/tests/language-feature/interfaces/is-as-dynamic.slang.expected.txt
new file mode 100644
index 000000000..00750edc0
--- /dev/null
+++ b/tests/language-feature/interfaces/is-as-dynamic.slang.expected.txt
@@ -0,0 +1 @@
+3
diff --git a/tests/language-feature/interfaces/is-as.slang b/tests/language-feature/interfaces/is-as.slang
new file mode 100644
index 000000000..2712d9810
--- /dev/null
+++ b/tests/language-feature/interfaces/is-as.slang
@@ -0,0 +1,46 @@
+// is-as.slang
+
+// Test that `is` and `as` operators works as intended.
+
+//TEST(compute):COMPARE_COMPUTE: -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+interface IFoo
+{
+ int method();
+}
+
+struct Impl1 : IFoo
+{
+ int data;
+ int method() { return data; }
+}
+
+struct Impl2 : IFoo
+{
+ int data1;
+ int data2;
+ int method() { return data1 + data2; }
+}
+
+int getData(IFoo foo)
+{
+ let castResult = foo as Impl2;
+ if (castResult.hasValue)
+ {
+ return castResult.value.method();
+ }
+ return 0;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ Impl2 obj;
+ obj.data1 = 1;
+ obj.data2 = 2;
+ int outVal = getData(obj);
+ outputBuffer[0] = outVal;
+}
diff --git a/tests/language-feature/interfaces/is-as.slang.expected.txt b/tests/language-feature/interfaces/is-as.slang.expected.txt
new file mode 100644
index 000000000..00750edc0
--- /dev/null
+++ b/tests/language-feature/interfaces/is-as.slang.expected.txt
@@ -0,0 +1 @@
+3