summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/command-line-slangc-reference.md8
-rw-r--r--include/slang.h45
-rw-r--r--source/core/slang-string-escape-util.cpp16
-rw-r--r--source/core/slang-string-escape-util.h3
-rw-r--r--source/slang-record-replay/record/slang-session.cpp16
-rw-r--r--source/slang-record-replay/record/slang-session.h5
-rw-r--r--source/slang-record-replay/util/record-format.h1
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/diff.meta.slang39
-rw-r--r--source/slang/slang-ast-natural-layout.cpp10
-rw-r--r--source/slang/slang-check-conversion.cpp7
-rw-r--r--source/slang/slang-check-decl.cpp19
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-check-shader.cpp124
-rw-r--r--source/slang/slang-compiler-options.cpp1
-rw-r--r--source/slang/slang-compiler.h5
-rw-r--r--source/slang/slang-diagnostic-defs.h11
-rw-r--r--source/slang/slang-emit-c-like.cpp3
-rw-r--r--source/slang/slang-emit.cpp46
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp36
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp7
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h43
-rw-r--r--source/slang/slang-ir-autodiff.cpp14
-rw-r--r--source/slang/slang-ir-autodiff.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir-lower-existential.cpp17
-rw-r--r--source/slang/slang-ir-lower-optional-type.cpp107
-rw-r--r--source/slang/slang-ir-marshal-native-call.cpp11
-rw-r--r--source/slang/slang-ir-peephole.cpp34
-rw-r--r--source/slang/slang-ir.cpp9
-rw-r--r--source/slang/slang-options.cpp15
-rw-r--r--source/slang/slang-parser.cpp5
-rw-r--r--source/slang/slang-type-layout.cpp2
-rw-r--r--source/slang/slang.cpp25
-rw-r--r--tests/autodiff/optional.slang59
-rw-r--r--tests/initializer-list/existential-is-not-c-like.slang21
-rw-r--r--tests/language-feature/interfaces/optional-none.slang47
-rw-r--r--tests/language-feature/interfaces/zero-init-interface.slang33
-rw-r--r--tests/language-feature/nested-optional.slang35
41 files changed, 779 insertions, 125 deletions
diff --git a/docs/command-line-slangc-reference.md b/docs/command-line-slangc-reference.md
index 2d684f961..91e3be5b5 100644
--- a/docs/command-line-slangc-reference.md
+++ b/docs/command-line-slangc-reference.md
@@ -343,6 +343,14 @@ Treat enums types as unscoped by default.
Preserve all resource parameters in the output code, even if they are not used by the shader.
+<a id="conformance"></a>
+### -conformance
+
+**-conformance &lt;typeName&gt;:&lt;interfaceName&gt;\[=&lt;sequentialID&gt;\]**
+
+Include additional type conformance during linking for dynamic dispatch.
+
+
<a id="reflection-json"></a>
### -reflection-json
diff --git a/include/slang.h b/include/slang.h
index dc52e80b3..11538bcac 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -906,7 +906,6 @@ typedef uint32_t SlangSizeT;
DisableSourceMap, // bool
UnscopedEnum, // bool
PreserveParameters, // bool: preserve all resource parameters in the output code.
-
// Target
Capability, // intValue0: CapabilityName
@@ -998,8 +997,13 @@ typedef uint32_t SlangSizeT;
TrackLiveness,
LoopInversion, // bool, enable loop inversion optimization
- // Deprecated
- ParameterBlocksUseRegisterSpaces,
+ ParameterBlocksUseRegisterSpaces, // Deprecated
+ LanguageVersion, // intValue0: SlangLanguageVersion
+ TypeConformance, // stringValue0: additional type conformance to link, in the format of
+ // "<TypeName>:<IInterfaceName>[=<sequentialId>]", for example
+ // "Impl:IFoo=3" or "Impl:IFoo".
+ EnableExperimentalDynamicDispatch, // bool, experimental
+ EmitReflectionJSON, // bool
CountOfParsableOptions,
@@ -1016,14 +1020,11 @@ typedef uint32_t SlangSizeT;
// Setting of EmitSpirvDirectly or EmitSpirvViaGLSL will turn into this option internally.
EmitSpirvMethod, // enum SlangEmitSpirvMethod
- EmitReflectionJSON, // bool
SaveGLSLModuleBinSource,
SkipDownstreamLinking, // bool, experimental
DumpModule,
- EnableExperimentalDynamicDispatch, // bool, experimental
- LanguageVersion, // intValue0: SlangLanguageVersion
CountOf,
};
@@ -4052,6 +4053,7 @@ struct ISession : public ISlangUnknown
ISlangBlob** outNameBlob) = 0;
/** Get the sequential ID used to identify a type witness in a dynamic object.
+ The sequential ID is part of the RTTI bytes returned by `getDynamicObjectRTTIBytes`.
*/
virtual SLANG_NO_THROW SlangResult SLANG_MCALL getTypeConformanceWitnessSequentialID(
slang::TypeReflection* type,
@@ -4113,6 +4115,37 @@ struct ISession : public ISlangUnknown
const char* path,
const char* string,
slang::IBlob** outDiagnostics = nullptr) = 0;
+
+
+ /** Get the 16-byte RTTI header to fill into a dynamic object.
+ This header is used to identify the type of the object for dynamic dispatch purpose.
+ For example, given the following shader:
+
+ ```slang
+ [anyValueSize(32)] dyn interface IFoo { int eval(); }
+ struct Impl : IFoo { int eval() { return 1; } }
+
+ ConstantBuffer<dyn IFoo> cb0;
+
+ [numthreads(1,1,1)
+ void main()
+ {
+ cb0.eval();
+ }
+ ```
+
+ The constant buffer `cb0` should be filled with 16+32=48 bytes of data, where the first
+ 16 bytes should be the RTTI bytes returned by calling `getDynamicObjectRTTIBytes(type_Impl,
+ type_IFoo)`, and the rest 32 bytes should hold the actual data of the dynamic object (in
+ this case, fields in the `Impl` type).
+
+ `bufferSizeInBytes` must be greater than 16.
+ */
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ uint32_t* outRTTIDataBuffer,
+ uint32_t bufferSizeInBytes) = 0;
};
#define SLANG_UUID_ISession ISession::getTypeGuid()
diff --git a/source/core/slang-string-escape-util.cpp b/source/core/slang-string-escape-util.cpp
index 0645d94ba..c079b8b39 100644
--- a/source/core/slang-string-escape-util.cpp
+++ b/source/core/slang-string-escape-util.cpp
@@ -1099,6 +1099,22 @@ StringEscapeUtil::Handler* StringEscapeUtil::getHandler(Style style)
}
}
+/* static */ UnownedStringSlice StringEscapeUtil::maybeUnquoteCommandLineArg(
+ UnownedStringSlice slice)
+{
+ // If the slice is quoted, unquote it, else return as is
+ if (slice.startsWith("\'") || slice.startsWith("\""))
+ {
+ const Index len = slice.getLength();
+ if (len >= 2 && slice[len - 1] == slice[0])
+ {
+ // Unquote it
+ return UnownedStringSlice(slice.begin() + 1, len - 2);
+ }
+ }
+ return slice;
+}
+
/* static */ bool StringEscapeUtil::isQuoted(char quoteChar, UnownedStringSlice& slice)
{
const Index len = slice.getLength();
diff --git a/source/core/slang-string-escape-util.h b/source/core/slang-string-escape-util.h
index ece8de79f..07b3bcc3d 100644
--- a/source/core/slang-string-escape-util.h
+++ b/source/core/slang-string-escape-util.h
@@ -79,6 +79,9 @@ struct StringEscapeUtil
return isQuoted(handler->getQuoteChar(), slice);
}
+ /// Given a command line arg slice, if it is quoted, unquotes it, else returns the slice as is.
+ static UnownedStringSlice maybeUnquoteCommandLineArg(UnownedStringSlice slice);
+
/// If quoting is needed appends to out quoted
static SlangResult appendMaybeQuoted(
Handler* handler,
diff --git a/source/slang-record-replay/record/slang-session.cpp b/source/slang-record-replay/record/slang-session.cpp
index d290afe0d..800d690fa 100644
--- a/source/slang-record-replay/record/slang-session.cpp
+++ b/source/slang-record-replay/record/slang-session.cpp
@@ -369,6 +369,22 @@ SLANG_NO_THROW SlangResult SessionRecorder::getTypeConformanceWitnessMangledName
return result;
}
+SLANG_NO_THROW SlangResult SessionRecorder::getDynamicObjectRTTIBytes(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ uint32_t* outRTTIDataBuffer,
+ uint32_t bufferSizeInBytes)
+{
+ // No need to record this function, it's just a query.
+
+ SlangResult result = m_actualSession->getDynamicObjectRTTIBytes(
+ type,
+ interfaceType,
+ outRTTIDataBuffer,
+ bufferSizeInBytes);
+ return result;
+}
+
SLANG_NO_THROW SlangResult SessionRecorder::getTypeConformanceWitnessSequentialID(
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
diff --git a/source/slang-record-replay/record/slang-session.h b/source/slang-record-replay/record/slang-session.h
index ea76d0dde..9cff7beac 100644
--- a/source/slang-record-replay/record/slang-session.h
+++ b/source/slang-record-replay/record/slang-session.h
@@ -69,6 +69,11 @@ public:
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
uint32_t* outId) override;
+ SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ uint32_t* outRTTIDataBuffer,
+ uint32_t bufferSizeInBytes) override;
SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
diff --git a/source/slang-record-replay/util/record-format.h b/source/slang-record-replay/util/record-format.h
index f1ae2e71b..99915c46f 100644
--- a/source/slang-record-replay/util/record-format.h
+++ b/source/slang-record-replay/util/record-format.h
@@ -112,7 +112,6 @@ enum ApiCallId : uint32_t
ISession_getLoadedModule = makeApiCallId(Class_ISession, 0x0012),
ISession_isBinaryModuleUpToDate = makeApiCallId(Class_ISession, 0x0013),
-
IModule_findEntryPointByName = makeApiCallId(Class_IModule, 0x0001),
IModule_getDefinedEntryPointCount = makeApiCallId(Class_IModule, 0x0002),
IModule_getDefinedEntryPoint = makeApiCallId(Class_IModule, 0x0003),
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 140c9ba16..484f51bfc 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1463,7 +1463,7 @@ typealias __Addr<T> = Ptr<T, $( (uint64_t)AddressSpace::Generic)ULL>;
__generic<T>
__magic_type(OptionalType)
__intrinsic_type($(kIROp_OptionalType))
-struct Optional
+struct Optional : IDefaultInitializable
{
/// Return `true` iff this `Optional` contains a value of type `T`
property bool hasValue
@@ -1482,6 +1482,9 @@ struct Optional
__implicit_conversion($(kConversionCost_ValToOptional))
__intrinsic_op($(kIROp_MakeOptionalValue))
__init(T val);
+
+ [__unsafeForceInlineEarly]
+ __init() { this = none; }
};
//@hidden:
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c24a8b11a..13c5d2d47 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1396,6 +1396,45 @@ extension Array<T, N> : IDifferentiablePtrType
typedef Array<T.Differential, N> Differential;
}
+__generic<T:IDifferentiable>
+extension Optional<T> : IDifferentiable
+{
+ typedef Optional<T.Differential> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Optional<T.Differential>();
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ if (!a.hasValue)
+ return b;
+ if (b.hasValue)
+ return T.dadd(a.value, b.value);
+ else
+ return a;
+ }
+
+ __generic<U : __BuiltinRealType>
+ [__unsafeForceInlineEarly]
+ static Differential dmul(U a, Differential b)
+ {
+ if (b.hasValue)
+ return Optional<T.Differential>(T.dmul<U>(a, b.value));
+ else
+ return b;
+ }
+}
+
+__generic<T : IDifferentiablePtrType>
+extension Optional<T> : IDifferentiablePtrType
+{
+ typedef Optional<T.Differential> Differential;
+}
+
__generic<each T : IDifferentiable>
extension Tuple<T> : IDifferentiable
{
diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp
index 8bfc5f8ce..f15dee1d1 100644
--- a/source/slang/slang-ast-natural-layout.cpp
+++ b/source/slang/slang-ast-natural-layout.cpp
@@ -4,6 +4,7 @@
#include "slang-ast-builder.h"
// For BaseInfo
+#include "slang-check-impl.h"
#include "slang-compiler.h"
namespace Slang
@@ -165,6 +166,15 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
return size;
}
+ else if (auto optionalType = as<OptionalType>(type))
+ {
+ if (isNullableType(optionalType->getValueType()))
+ return calcSize(optionalType->getValueType());
+ NaturalSize size = NaturalSize::makeEmpty();
+ size.append(calcSize(m_astBuilder->getBoolType()));
+ size.append(calcSize(optionalType->getValueType()));
+ return size;
+ }
else if (auto declRefType = as<DeclRefType>(type))
{
if (const auto enumDeclRef = declRefType->getDeclRef().as<EnumDecl>())
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index fcbc83673..41355597f 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -250,6 +250,13 @@ bool SemanticsVisitor::isCStyleType(Type* type, HashSet<Type*>& isVisit)
as<PtrType>(type))
return cacheResult(true);
+ // Slang 2026 language fix: an interface type is not C-style.
+ if (isSlang2026OrLater(this))
+ {
+ // TODO: some/dyn types are also not C-style.
+ if (isDeclRefTypeOf<InterfaceDecl>(type))
+ return cacheResult(false);
+ }
// A tuple type is C-style if all of its members are C-style.
if (auto tupleType = as<TupleType>(type))
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5aff41988..e3b05ec00 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -34,7 +34,7 @@ static bool isAssociatedTypeDecl(Decl* decl)
return false;
}
-static bool isSlang2026OrLater(SemanticsVisitor* visitor)
+bool isSlang2026OrLater(SemanticsVisitor* visitor)
{
return visitor->getShared()->m_module->getModuleDecl()->languageVersion >=
SLANG_LANGUAGE_VERSION_2026;
@@ -1604,6 +1604,23 @@ EnumDecl* isEnumType(Type* type)
return nullptr;
}
+bool isNullableType(Type* type)
+{
+ if (as<PtrTypeBase>(type))
+ return true;
+ if (isDeclRefTypeOf<InterfaceDecl>(type))
+ return true;
+ if (isDeclRefTypeOf<ClassDecl>(type))
+ return true;
+ if (as<OptionalType>(type))
+ return true;
+ if (as<RefTypeBase>(type))
+ return true;
+ if (as<NativeStringType>(type))
+ return true;
+ return false;
+}
+
bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state)
{
if (state < DeclCheckState::DefinitionChecked)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 80436e68a..1cdebb115 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -42,6 +42,8 @@ bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl);
bool isUniformParameterType(Type* type);
+bool isSlang2026OrLater(SemanticsVisitor* visitor);
+
/// Create a new component type based on `inComponentType`, but with all its requiremetns filled.
RefPtr<ComponentType> fillRequirements(ComponentType* inComponentType);
@@ -3115,6 +3117,10 @@ bool isUnsizedArrayType(Type* type);
bool isInterfaceType(Type* type);
+// Check if `type` is nullable. An `Optional<T>` will occupy the same space as `T`, if `T`
+// is nullable.
+bool isNullableType(Type* type);
+
EnumDecl* isEnumType(Type* type);
DeclVisibility getDeclVisibility(Decl* decl);
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 6a713f412..a57e01a88 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -915,6 +915,51 @@ RefPtr<ComponentType> fillRequirements(ComponentType* inComponentType)
return componentType;
}
+bool parseTypeConformanceArgString(
+ UnownedStringSlice optionString,
+ UnownedStringSlice& outTypeName,
+ UnownedStringSlice& outInterfaceName,
+ Index& outSequentialId)
+{
+ // The expected format for the type conformance argument is:
+ // `TypeName:InterfaceName[=SequentialId]`
+ //
+ // Where `TypeName` is the name of a concrete type, `InterfaceName`
+ // is the name of an interface type, and `SequentialId` is an optional
+ // integer that specifies a sequential ID for the conformance.
+ //
+ // If the string does not match this format, we will return false.
+
+ outTypeName = UnownedStringSlice();
+ outInterfaceName = UnownedStringSlice();
+ outSequentialId = -1;
+ auto colonPos = optionString.indexOf(':');
+ if (colonPos < 0)
+ {
+ // If there is no colon, then the string is invalid.
+ return false;
+ }
+ outTypeName = optionString.head(colonPos);
+ auto interfaceNameStart = colonPos + 1;
+ auto equalsPos = optionString.indexOf('=');
+ if (equalsPos < interfaceNameStart)
+ {
+ // If there is no equals sign, then the interface name goes to the end of the string.
+ outInterfaceName = optionString.tail(interfaceNameStart);
+ }
+ else
+ {
+ // If there is an equals sign, then the interface name goes up to that point.
+ outInterfaceName =
+ optionString.subString(interfaceNameStart, equalsPos - interfaceNameStart);
+ // The sequential ID is the part after the equals sign.
+ auto sequentialIdString = optionString.tail(equalsPos + 1);
+ if (SLANG_FAILED(StringUtil::parseInt(sequentialIdString, outSequentialId)))
+ return false;
+ }
+ return true;
+}
+
/// Create a component type to represent the "global scope" of a compile request.
///
/// This component type will include all the modules and their global
@@ -965,6 +1010,85 @@ RefPtr<ComponentType> createUnspecializedGlobalComponentType(FrontEndCompileRequ
CompositeComponentType::create(linkage, translationUnitComponentTypes);
}
+ List<RefPtr<ComponentType>> conformanceComponents;
+
+ // Find and include all type conformances specified through compiler options.
+ for (auto conformances :
+ compileRequest->optionSet.getArray(CompilerOptionName::TypeConformance))
+ {
+ auto stringValue = conformances.stringValue.getUnownedSlice();
+ UnownedStringSlice typeName, interfaceName;
+ Index sequentialId = -1;
+ if (!parseTypeConformanceArgString(stringValue, typeName, interfaceName, sequentialId))
+ {
+ compileRequest->getSink()->diagnose(
+ SourceLoc(),
+ Diagnostics::invalidTypeConformanceOptionString,
+ stringValue);
+ continue;
+ }
+ auto concreteType = globalComponentType->getTypeFromString(
+ String(typeName).getBuffer(),
+ compileRequest->getSink());
+ if (!concreteType)
+ {
+ compileRequest->getSink()->diagnose(
+ SourceLoc(),
+ Diagnostics::invalidTypeConformanceOptionNoType,
+ stringValue,
+ typeName);
+ continue;
+ }
+ auto interfaceType = globalComponentType->getTypeFromString(
+ String(interfaceName).getBuffer(),
+ compileRequest->getSink());
+ if (!interfaceType)
+ {
+ compileRequest->getSink()->diagnose(
+ SourceLoc(),
+ Diagnostics::invalidTypeConformanceOptionNoType,
+ stringValue,
+ interfaceName);
+ continue;
+ }
+ ComPtr<slang::ITypeConformance> conformanceComponent;
+ ComPtr<ISlangBlob> diagnostics;
+ compileRequest->getLinkage()->createTypeConformanceComponentType(
+ (slang::TypeReflection*)concreteType,
+ (slang::TypeReflection*)interfaceType,
+ conformanceComponent.writeRef(),
+ sequentialId,
+ diagnostics.writeRef());
+ if (!conformanceComponent)
+ {
+ // If we failed to create the conformance component, then
+ // we should report the diagnostics that were generated.
+ //
+ compileRequest->getSink()->diagnose(
+ SourceLoc(),
+ Diagnostics::cannotCreateTypeConformance,
+ stringValue);
+ if (diagnostics)
+ {
+ compileRequest->getSink()->diagnoseRaw(
+ Severity::Error,
+ UnownedStringSlice((char*)diagnostics->getBufferPointer()));
+ }
+ continue;
+ }
+ conformanceComponents.add(static_cast<TypeConformance*>(conformanceComponent.get()));
+ }
+
+ if (conformanceComponents.getCount() > 0)
+ {
+ // If we found any type conformances, then we will
+ // create a composite component type that includes
+ // the global component type and the conformance components.
+ //
+ conformanceComponents.add(globalComponentType);
+ globalComponentType = CompositeComponentType::create(linkage, conformanceComponents);
+ }
+
return fillRequirements(globalComponentType);
}
diff --git a/source/slang/slang-compiler-options.cpp b/source/slang/slang-compiler-options.cpp
index 5c17121cc..843e0e7cb 100644
--- a/source/slang/slang-compiler-options.cpp
+++ b/source/slang/slang-compiler-options.cpp
@@ -198,6 +198,7 @@ bool CompilerOptionSet::allowDuplicate(CompilerOptionName name)
case CompilerOptionName::DownstreamArgs:
case CompilerOptionName::VulkanBindShift:
case CompilerOptionName::VulkanBindShiftAll:
+ case CompilerOptionName::TypeConformance:
return true;
}
return false;
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index e06278599..8f7f7d49a 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2192,6 +2192,11 @@ public:
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
uint32_t* outId) override;
+ SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ uint32_t* outBytes,
+ uint32_t bufferSize) override;
SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType(
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 4aadfd78d..465602a33 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -249,6 +249,17 @@ DIAGNOSTIC(
cannotMatchOutputFileToEntryPoint,
"the output path '$0' is not associated with any entry point; a '-o' option for a compiled "
"kernel must follow the '-entry' option for its corresponding entry point")
+DIAGNOSTIC(
+ 71,
+ Error,
+ invalidTypeConformanceOptionString,
+ "syntax error in type conformance option '$0'.")
+DIAGNOSTIC(
+ 72,
+ Error,
+ invalidTypeConformanceOptionNoType,
+ "invalid conformance option '$0', type '$0' is not found.")
+DIAGNOSTIC(73, Error, cannotCreateTypeConformance, "cannot create type conformance '$0'.")
DIAGNOSTIC(
80,
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 26b964d6e..0534b64d4 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -5331,7 +5331,8 @@ void CLikeSourceEmitter::computeEmitActions(IRModule* module, List<EmitAction>&
// Skip resource types in this pass.
if (isResourceType(inst->getDataType()))
continue;
-
+ if (as<IRInterfaceRequirementEntry>(inst))
+ continue;
ensureGlobalInst(&ctx, inst, EmitAction::Level::Definition);
}
}
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index a4362b912..20459c722 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -778,9 +778,6 @@ Result linkAndOptimizeIR(
break;
}
- if (requiredLoweringPassSet.optionalType)
- lowerOptionalType(irModule, sink);
-
switch (target)
{
case CodeGenTarget::CUDASource:
@@ -792,20 +789,6 @@ Result linkAndOptimizeIR(
break;
}
- switch (target)
- {
- case CodeGenTarget::CPPSource:
- case CodeGenTarget::HostCPPSource:
- {
- lowerComInterfaces(irModule, artifactDesc.style, sink);
- generateDllImportFuncs(codeGenContext->getTargetProgram(), irModule, sink);
- generateDllExportFuncs(irModule, sink);
- break;
- }
- default:
- break;
- }
-
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED");
#endif
@@ -947,12 +930,6 @@ Result linkAndOptimizeIR(
break;
}
- // Lower `Result<T,E>` types into ordinary struct types. This must happen
- // after specialization, since otherwise incompatible copies of the lowered
- // result structure are generated.
- if (requiredLoweringPassSet.resultType)
- lowerResultType(irModule, sink);
-
// Report checkpointing information
if (codeGenContext->shouldReportCheckpointIntermediates())
{
@@ -978,6 +955,29 @@ Result linkAndOptimizeIR(
finalizeSpecialization(irModule);
+ // Lower `Result<T,E>` types into ordinary struct types. This must happen
+ // after specialization, since otherwise incompatible copies of the lowered
+ // result structure are generated.
+ if (requiredLoweringPassSet.resultType)
+ lowerResultType(irModule, sink);
+
+ if (requiredLoweringPassSet.optionalType)
+ lowerOptionalType(irModule, sink);
+
+ switch (target)
+ {
+ case CodeGenTarget::CPPSource:
+ case CodeGenTarget::HostCPPSource:
+ {
+ lowerComInterfaces(irModule, artifactDesc.style, sink);
+ generateDllImportFuncs(codeGenContext->getTargetProgram(), irModule, sink);
+ generateDllExportFuncs(irModule, sink);
+ break;
+ }
+ default:
+ break;
+ }
+
requiredLoweringPassSet = {};
calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst());
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index df690a4e2..003790793 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1308,6 +1308,32 @@ InstPair ForwardDiffTranscriber::transcribeGetTupleElement(IRBuilder* builder, I
return InstPair(primalGetElement, diffGetElement);
}
+InstPair ForwardDiffTranscriber::transcribeGetOptionalValue(
+ IRBuilder* builder,
+ IRInst* originalInst)
+{
+ IRInst* origBase = originalInst->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
+
+ IRInst* primalGetOptionalVal =
+ builder->emitIntrinsicInst(primalType, originalInst->getOp(), 1, &primalBase);
+
+ IRInst* diffGetOptionalVal = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalGetOptionalVal->getDataType()))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ diffGetOptionalVal =
+ builder->emitIntrinsicInst(diffType, originalInst->getOp(), 1, &diffBase);
+ }
+ }
+
+ return InstPair(primalGetOptionalVal, diffGetOptionalVal);
+}
+
InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst)
{
auto updateInst = as<IRUpdateElement>(originalInst);
@@ -2020,6 +2046,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
case kIROp_MakeTuple:
+ case kIROp_MakeOptionalValue:
+ case kIROp_MakeResultValue:
case kIROp_MakeValuePack:
case kIROp_BuiltinCast:
return transcribeConstruct(builder, origInst);
@@ -2063,6 +2091,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_GetTupleElement:
return transcribeGetTupleElement(builder, origInst);
+ case kIROp_GetOptionalValue:
+ return transcribeGetOptionalValue(builder, origInst);
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
@@ -2197,6 +2227,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeCoopVectorFromValuePack:
case kIROp_GetCurrentStage:
case kIROp_GetOffsetPtr:
+ case kIROp_IsNullExistential:
+ case kIROp_MakeResultError:
+ case kIROp_IsResultError:
+ case kIROp_GetResultError:
+ case kIROp_MakeOptionalNone:
+ case kIROp_OptionalHasValue:
return transcribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 09b3f14b8..1bdbb01c8 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -72,6 +72,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeGetTupleElement(IRBuilder* builder, IRInst* origInst);
+ InstPair transcribeGetOptionalValue(IRBuilder* builder, IRInst* originalInst);
+
InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst);
InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index d67d75997..d3d5d72a9 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -390,6 +390,13 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
diffTypeList.getBuffer());
}
+ case kIROp_OptionalType:
+ {
+ auto origOptionalType = as<IROptionalType>(primalType);
+ auto diffValueType = differentiateType(builder, origOptionalType->getValueType());
+ return builder->getOptionalType(diffValueType);
+ }
+
default:
return (IRType*)maybeCloneForPrimalInst(
builder,
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 606428efe..09f70725a 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1476,6 +1476,9 @@ struct DiffTransposePass
case kIROp_GetElement:
return transposeGetElement(builder, as<IRGetElement>(fwdInst), revValue);
+ case kIROp_GetOptionalValue:
+ return transposeGetOptionalValue(builder, as<IRGetOptionalValue>(fwdInst), revValue);
+
case kIROp_Return:
return transposeReturn(builder, as<IRReturn>(fwdInst), revValue);
@@ -1531,7 +1534,8 @@ struct DiffTransposePass
return transposeMakeTuple(builder, fwdInst, revValue);
case kIROp_MakeArrayFromElement:
return transposeMakeArrayFromElement(builder, fwdInst, revValue);
-
+ case kIROp_MakeOptionalValue:
+ return transposeMakeOptionalValue(builder, fwdInst, revValue);
case kIROp_UpdateElement:
return transposeUpdateElement(builder, fwdInst, revValue);
@@ -1673,6 +1677,20 @@ struct DiffTransposePass
fwdGetElement)));
}
+ TranspositionResult transposeGetOptionalValue(
+ IRBuilder* builder,
+ IRGetOptionalValue* fwdGetOptionalValue,
+ IRInst* revValue)
+ {
+ // dP = GetOptionalValue(dVal) -> dVal = MakeOptionalValue(dP)
+ auto optionalVal = fwdGetOptionalValue->getOperand(0);
+ return TranspositionResult(List<RevGradient>(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdGetOptionalValue->getOperand(0),
+ builder->emitMakeOptionalValue(optionalVal->getDataType(), revValue),
+ fwdGetOptionalValue)));
+ }
+
TranspositionResult transposeMakePair(
IRBuilder*,
IRMakeDifferentialPair* fwdMakePair,
@@ -1982,6 +2000,29 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+ TranspositionResult transposeMakeOptionalValue(
+ IRBuilder* builder,
+ IRInst* fwdMakeOptionalValue,
+ IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+
+ auto gradAtField = builder->emitGetOptionalValue(revValue);
+ auto diffZero = diffTypeContext.emitDZeroOfDiffInstType(
+ builder,
+ tryGetPrimalTypeFromDiffInst(fwdMakeOptionalValue->getOperand(0)));
+ IRInst* selectArgs[] = {builder->emitOptionalHasValue(revValue), gradAtField, diffZero};
+ builder->emitIntrinsicInst(gradAtField->getDataType(), kIROp_Select, 3, selectArgs);
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeOptionalValue->getOperand(0),
+ gradAtField,
+ fwdMakeOptionalValue));
+
+ // (A = MakeOptionalValue(F)) -> [(dF += dA.hasValue?dA.value:dzero)]
+ return TranspositionResult(gradients);
+ }
+
TranspositionResult transposeMakeStruct(
IRBuilder* builder,
IRInst* fwdMakeStruct,
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index c8dc3b480..133c257a8 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1754,7 +1754,19 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
(UInt)diffTypeList.getCount(),
diffTypeList.getBuffer());
}
-
+ case kIROp_OptionalType:
+ {
+ auto primalOptionalType = as<IROptionalType>(primalType);
+ if (auto diffElementType =
+ differentiateType(builder, primalOptionalType->getValueType()))
+ {
+ return builder->getOptionalType(diffElementType);
+ }
+ else
+ {
+ return nullptr;
+ }
+ }
default:
return (IRType*)getDifferentialForType(builder, (IRType*)primalType);
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index befd1f98a..970f490c9 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -320,6 +320,7 @@ struct DifferentiableTypeConformanceContext
}
case kIROp_TupleType:
case kIROp_TypePack:
+ case kIROp_OptionalType:
{
return differentiateType(builder, origType);
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 3b45d46b3..7a281bac4 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1225,7 +1225,7 @@ INST(GetValueFromBoundInterface, getValueFromBoundInterface, 1, 0)
INST(ExtractExistentialValue, extractExistentialValue, 1, 0)
INST(ExtractExistentialType, extractExistentialType, 1, HOISTABLE)
INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOISTABLE)
-
+INST(IsNullExistential, isNullExistential, 1, 0)
INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0)
INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2fff4e451..b5c1a6475 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3343,6 +3343,12 @@ struct IRExtractExistentialWitnessTable : IRInst
IR_LEAF_ISA(ExtractExistentialWitnessTable);
};
+struct IRIsNullExistential : IRInst
+{
+ IR_LEAF_ISA(IsNullExistential);
+};
+
+
/* Base class for instructions that track liveness */
struct IRLiveRangeMarker : IRInst
{
@@ -4059,6 +4065,9 @@ public:
/// Given an existential value, extract the underlying "real" type
IRType* emitExtractExistentialType(IRInst* existentialValue);
+ /// Given an existential value, return if it is empty/null.
+ IRInst* emitIsNullExistential(IRInst* existentialValue);
+
/// Given an existential value, extract the witness table showing how the value conforms to the
/// existential type.
IRInst* emitExtractExistentialWitnessTable(IRInst* existentialValue);
diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp
index c4ee6e6fc..076ed57bd 100644
--- a/source/slang/slang-ir-lower-existential.cpp
+++ b/source/slang/slang-ir-lower-existential.cpp
@@ -131,6 +131,19 @@ struct ExistentialLoweringContext
processExtractExistentialElement(inst, 2);
}
+ void processIsNullExistential(IRIsNullExistential* inst)
+ {
+ IRBuilder builder(sharedContext->module);
+ builder.setInsertBefore(inst);
+
+ auto rttiElement = extractTupleElement(&builder, inst->getOperand(0), 0);
+ auto isNull = builder.emitNeq(
+ builder.emitGetElement(builder.getUIntType(), rttiElement, 0),
+ builder.getIntValue(builder.getUIntType(), 0));
+ inst->replaceUsesWith(isNull);
+ inst->removeAndDeallocate();
+ }
+
void processExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst)
{
processExtractExistentialElement(inst, 1);
@@ -261,6 +274,10 @@ struct ExistentialLoweringContext
{
processExtractExistentialWitnessTable(extractExistentialWitnessTable);
}
+ else if (auto isNullExistential = as<IRIsNullExistential>(inst))
+ {
+ processIsNullExistential(isNullExistential);
+ }
}
void processModule()
diff --git a/source/slang/slang-ir-lower-optional-type.cpp b/source/slang/slang-ir-lower-optional-type.cpp
index 5c9dcd722..1f9f398d3 100644
--- a/source/slang/slang-ir-lower-optional-type.cpp
+++ b/source/slang/slang-ir-lower-optional-type.cpp
@@ -8,6 +8,13 @@
namespace Slang
{
+enum LoweredOptionalTypeKind
+{
+ Struct,
+ PtrValue,
+ ExistentialValue,
+};
+
struct OptionalTypeLoweringContext
{
IRModule* module;
@@ -16,10 +23,6 @@ struct OptionalTypeLoweringContext
InstWorkList workList;
InstHashSet workListSet;
- IRGeneric* genericOptionalStructType = nullptr;
- IRStructKey* valueKey = nullptr;
- IRStructKey* hasValueKey = nullptr;
-
OptionalTypeLoweringContext(IRModule* inModule)
: module(inModule), workList(inModule), workListSet(inModule)
{
@@ -30,6 +33,9 @@ struct OptionalTypeLoweringContext
IRType* optionalType = nullptr;
IRType* valueType = nullptr;
IRType* loweredType = nullptr;
+ IRStructKey* hasValueKey = nullptr;
+ IRStructKey* valueKey = nullptr;
+ LoweredOptionalTypeKind kind = LoweredOptionalTypeKind::Struct;
};
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo;
Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes;
@@ -42,37 +48,29 @@ struct OptionalTypeLoweringContext
return type;
}
- IRInst* getOrCreateGenericOptionalStruct()
+ IRInst* createOptionalStruct(IRType* type, LoweredOptionalTypeInfo* info)
{
- if (genericOptionalStructType)
- return genericOptionalStructType;
IRBuilder builder(module);
builder.setInsertInto(module->getModuleInst());
- valueKey = builder.createStructKey();
- builder.addNameHintDecoration(valueKey, UnownedStringSlice("value"));
- hasValueKey = builder.createStructKey();
- builder.addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue"));
-
- genericOptionalStructType = builder.emitGeneric();
- builder.addNameHintDecoration(
- genericOptionalStructType,
- UnownedStringSlice("_slang_Optional"));
+ info->valueKey = builder.createStructKey();
+ builder.addNameHintDecoration(info->valueKey, UnownedStringSlice("value"));
+ info->hasValueKey = builder.createStructKey();
+ builder.addNameHintDecoration(info->hasValueKey, UnownedStringSlice("hasValue"));
- builder.setInsertInto(genericOptionalStructType);
- auto block = builder.emitBlock();
- auto typeParam = builder.emitParam(builder.getTypeKind());
auto structType = builder.createStructType();
- builder.addNameHintDecoration(structType, UnownedStringSlice("_slang_Optional"));
- builder.createStructField(structType, valueKey, (IRType*)typeParam);
- builder.createStructField(structType, hasValueKey, builder.getBoolType());
- builder.setInsertInto(block);
- builder.emitReturn(structType);
- genericOptionalStructType->setFullType(builder.getTypeKind());
- return genericOptionalStructType;
+ StringBuilder sb;
+ sb << "_slang_Optional_";
+ getTypeNameHint(sb, type);
+ builder.addNameHintDecoration(structType, sb.getUnownedSlice());
+ builder.createStructField(structType, info->valueKey, type);
+ builder.createStructField(structType, info->hasValueKey, builder.getBoolType());
+
+ info->kind = LoweredOptionalTypeKind::Struct;
+ return structType;
}
- bool typeHasNullValue(IRInst* type)
+ bool typeHasNullValue(IRInst* type, LoweredOptionalTypeKind& outKind)
{
switch (type->getOp())
{
@@ -81,21 +79,25 @@ struct OptionalTypeLoweringContext
case kIROp_NativeStringType:
case kIROp_PtrType:
case kIROp_ClassType:
+ outKind = LoweredOptionalTypeKind::PtrValue;
return true;
case kIROp_InterfaceType:
- return isComInterfaceType((IRType*)type);
+ if (isComInterfaceType((IRType*)type))
+ outKind = LoweredOptionalTypeKind::PtrValue;
+ else
+ outKind = LoweredOptionalTypeKind::ExistentialValue;
+ return true;
default:
return false;
}
}
- LoweredOptionalTypeInfo* getLoweredOptionalType(IRBuilder* builder, IRInst* type)
+ LoweredOptionalTypeInfo* getLoweredOptionalType(IRBuilder*, IRInst* type)
{
if (auto loweredInfo = loweredOptionalTypes.tryGetValue(type))
return loweredInfo->Ptr();
if (auto loweredInfo = mapLoweredTypeToOptionalTypeInfo.tryGetValue(type))
return loweredInfo->Ptr();
-
if (!type)
return nullptr;
if (type->getOp() != kIROp_OptionalType)
@@ -104,19 +106,21 @@ struct OptionalTypeLoweringContext
RefPtr<LoweredOptionalTypeInfo> info = new LoweredOptionalTypeInfo();
auto optionalType = cast<IROptionalType>(type);
auto valueType = optionalType->getValueType();
+ while (auto valueOptionalType = as<IROptionalType>(valueType))
+ {
+ // If the value type is also an Optional, we need to keep lowering it.
+ valueType = valueOptionalType->getValueType();
+ }
+
info->optionalType = (IRType*)type;
info->valueType = valueType;
- if (typeHasNullValue(valueType))
+ if (typeHasNullValue(valueType, info->kind))
{
info->loweredType = valueType;
}
else
{
- auto genericType = getOrCreateGenericOptionalStruct();
- IRInst* args[] = {valueType};
- auto specializedType =
- builder->emitSpecializeInst(builder->getTypeKind(), genericType, 1, args);
- info->loweredType = (IRType*)specializedType;
+ info->loweredType = (IRType*)createOptionalStruct(valueType, info);
}
mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info;
loweredOptionalTypes[type] = info;
@@ -171,6 +175,12 @@ struct OptionalTypeLoweringContext
inst->replaceUsesWith(makeStruct);
inst->removeAndDeallocate();
}
+ else if (info->kind == LoweredOptionalTypeKind::ExistentialValue)
+ {
+ auto zero = builder->emitDefaultConstruct(info->loweredType);
+ inst->replaceUsesWith(zero);
+ inst->removeAndDeallocate();
+ }
else
{
inst->replaceUsesWith(builder->getNullPtrValue(info->valueType));
@@ -183,13 +193,20 @@ struct OptionalTypeLoweringContext
auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, optionalInst->getDataType());
SLANG_ASSERT(loweredOptionalTypeInfo);
IRInst* result = nullptr;
- if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType)
- {
- result = builder->emitFieldExtract(builder->getBoolType(), optionalInst, hasValueKey);
- }
- else
+ switch (loweredOptionalTypeInfo->kind)
{
+ case LoweredOptionalTypeKind::Struct:
+ result = builder->emitFieldExtract(
+ builder->getBoolType(),
+ optionalInst,
+ loweredOptionalTypeInfo->hasValueKey);
+ break;
+ case LoweredOptionalTypeKind::PtrValue:
result = builder->emitCastPtrToBool(optionalInst);
+ break;
+ case LoweredOptionalTypeKind::ExistentialValue:
+ result = builder->emitIsNullExistential(optionalInst);
+ break;
}
return result;
}
@@ -214,11 +231,13 @@ struct OptionalTypeLoweringContext
auto base = inst->getOptionalOperand();
auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, base->getDataType());
- if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType)
+ if (loweredOptionalTypeInfo->kind == LoweredOptionalTypeKind::Struct)
{
SLANG_ASSERT(loweredOptionalTypeInfo);
- auto getElement =
- builder->emitFieldExtract(loweredOptionalTypeInfo->valueType, base, valueKey);
+ auto getElement = builder->emitFieldExtract(
+ loweredOptionalTypeInfo->valueType,
+ base,
+ loweredOptionalTypeInfo->valueKey);
inst->replaceUsesWith(getElement);
}
else
diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp
index 1d3f04318..f19bf2168 100644
--- a/source/slang/slang-ir-marshal-native-call.cpp
+++ b/source/slang/slang-ir-marshal-native-call.cpp
@@ -17,6 +17,17 @@ IRType* NativeCallMarshallingContext::getNativeType(IRBuilder& builder, IRType*
return builder.getNativePtrType(type);
case kIROp_ComPtrType:
return builder.getNativePtrType((IRType*)as<IRComPtrType>(type)->getOperand(0));
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ {
+ auto arrayType = as<IRArrayType>(type);
+ auto elementType = arrayType->getElementType();
+ auto nativeElementType = getNativeType(builder, elementType);
+ return builder.getArrayTypeBase(
+ elementType->getOp(),
+ nativeElementType,
+ arrayType->getElementCount());
+ }
case kIROp_InOutType:
case kIROp_RefType:
case kIROp_ConstRefType:
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 28e98fdb6..7f3ff1d68 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -802,13 +802,33 @@ struct PeepholeContext : InstPassBase
{
if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
{
- IRBuilder builder(module);
- IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
- builder.setInsertBefore(inst);
- auto trueVal = builder.getBoolValue(true);
- inst->replaceUsesWith(trueVal);
- maybeRemoveOldInst(inst);
- changed = true;
+ auto getHasValue = as<IROptionalHasValue>(inst);
+ auto optionalType =
+ as<IROptionalType>(getHasValue->getOptionalOperand()->getDataType());
+ if (!optionalType)
+ break;
+ if (as<IROptionalType>(optionalType->getValueType()))
+ {
+ // HasValue(o : Optional<Optional<T>>) ==> HasValue(o.value).
+ IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+ builder.setInsertBefore(inst);
+ auto newVal = builder.emitOptionalHasValue(
+ builder.emitGetOptionalValue(getHasValue->getOptionalOperand()));
+ inst->replaceUsesWith(newVal);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
+ else
+ {
+ IRBuilder builder(module);
+ IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc);
+ builder.setInsertBefore(inst);
+ auto trueVal = builder.getBoolValue(true);
+ inst->replaceUsesWith(trueVal);
+ maybeRemoveOldInst(inst);
+ changed = true;
+ }
}
else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone)
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f571ec20b..e66ad69ce 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3515,6 +3515,14 @@ IRInst* IRBuilder::emitExtractExistentialValue(IRType* type, IRInst* existential
return inst;
}
+IRInst* IRBuilder::emitIsNullExistential(IRInst* existentialValue)
+{
+ auto inst =
+ createInst<IRInst>(this, kIROp_IsNullExistential, getBoolType(), 1, &existentialValue);
+ addInst(inst);
+ return inst;
+}
+
IRType* IRBuilder::emitExtractExistentialType(IRInst* existentialValue)
{
auto type = getTypeKind();
@@ -8648,6 +8656,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_IsNullExistential:
case kIROp_WrapExistential:
case kIROp_BuiltinCast:
case kIROp_BitCast:
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index 3ce107135..c269b10c4 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -530,6 +530,10 @@ void initCommandOptions(CommandOptions& options)
nullptr,
"Preserve all resource parameters in the output code, even if they are not used by the "
"shader."},
+ {OptionKind::TypeConformance,
+ "-conformance",
+ "-conformance <typeName>:<interfaceName>[=<sequentialID>]",
+ "Include additional type conformance during linking for dynamic dispatch."},
{OptionKind::EmitReflectionJSON,
"-reflection-json",
"reflection-json <path>",
@@ -2736,6 +2740,17 @@ SlangResult OptionsParser::_parse(int argc, char const* const* argv)
m_compileRequest->addSearchPath(String(slice).getBuffer());
break;
}
+ case OptionKind::TypeConformance:
+ {
+ if (!m_reader.hasArg())
+ break;
+ CommandLineArg operand;
+ SLANG_RETURN_ON_FAIL(m_reader.expectArg(operand));
+ auto unquoted =
+ StringEscapeUtil::maybeUnquoteCommandLineArg(operand.value.getUnownedSlice());
+ linkage->m_optionSet.add(OptionKind::TypeConformance, unquoted);
+ break;
+ }
case OptionKind::Output:
{
//
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 398aab517..431cf6669 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -6124,11 +6124,6 @@ Stmt* Parser::parseIfLetStatement()
tempVarDecl->nameAndLoc = NameLoc(getName(this, "$OptVar"), identifierToken.loc);
tempVarDecl->initExpr = initExpr;
AddMember(currentScope->containerDecl, tempVarDecl);
- if (semanticsVisitor)
- semanticsVisitor->ensureDecl(
- (Decl*)tempVarDecl,
- DeclCheckState::DefinitionChecked,
- nullptr);
DeclStmt* tmpVarDeclStmt = astBuilder->create<DeclStmt>();
FillPosition(tmpVarDeclStmt);
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 2c72b61de..b2171823b 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -4946,6 +4946,8 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type
else if (auto optionalType = as<OptionalType>(type))
{
// OptionalType should be laid out the same way as Tuple<T, bool>.
+ if (isNullableType(optionalType->getValueType()))
+ return _createTypeLayout(context, optionalType->getValueType());
Array<Type*, 2> types =
makeArray(optionalType->getValueType(), context.astBuilder->getBoolType());
auto tupleType = context.astBuilder->getTupleType(types.getView());
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 9aa8c56a7..065b9de93 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -2029,6 +2029,31 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent
return SLANG_OK;
}
+SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getDynamicObjectRTTIBytes(
+ slang::TypeReflection* type,
+ slang::TypeReflection* interfaceType,
+ uint32_t* outBuffer,
+ uint32_t bufferSize)
+{
+ // Slang RTTI header format:
+ // byte 0-7: pointer to RTTI struct describing the type. (not used for now, set to 1 for valid
+ // types, and 0 to represent null).
+ // byte 8-11: 32-bit sequential ID of the type conformance witness.
+ // byte 12-15: unused.
+
+ if (bufferSize < 16)
+ return SLANG_E_BUFFER_TOO_SMALL;
+
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+
+ SLANG_RETURN_ON_FAIL(getTypeConformanceWitnessSequentialID(type, interfaceType, outBuffer + 2));
+
+ // Make the RTTI part non zero.
+ outBuffer[0] = 1;
+
+ return SLANG_OK;
+}
+
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType(
slang::TypeReflection* type,
slang::TypeReflection* interfaceType,
diff --git a/tests/autodiff/optional.slang b/tests/autodiff/optional.slang
new file mode 100644
index 000000000..a86440413
--- /dev/null
+++ b/tests/autodiff/optional.slang
@@ -0,0 +1,59 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -output-using-type
+//TEST(compute,vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -slang -compute -output-using-type
+
+[Differentiable]
+Optional<float> sumSquare(Optional<float> a, Optional<float> b)
+{
+ if (let x = a)
+ {
+ if (let y = b)
+ {
+ return x * x + y * y;
+ }
+ else
+ {
+ return x * x;
+ }
+ }
+ else if (let y = b)
+ {
+ return y * y;
+ }
+ return none;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ var dpa = diffPair<Optional<float>>(3.0f, none);
+ var dpb = diffPair<Optional<float>>(4.0f, none);
+
+ bwd_diff(sumSquare)(dpa, dpb, 1.0f);
+
+ outputBuffer[0] = -1;
+
+ // CHECK: 14.0
+ if (dpa.d.hasValue && dpb.d.hasValue)
+ outputBuffer[0] = dpa.d.value + dpb.d.value;
+
+ // CHECK: 1.0
+ dpa = diffPair<Optional<float>>(3.0f, none);
+ dpb = diffPair<Optional<float>>(4.0f, none);
+ bwd_diff(sumSquare)(dpa, dpb, none);
+ if (dpa.d.value == 0.0 && dpb.d.value == 0.0)
+ {
+ outputBuffer[1] = 1.0f;
+ }
+
+ // CHECK: 100.0
+ dpa = diffPair<Optional<float>>(none, none);
+ dpb = diffPair<Optional<float>>(4.0f, none);
+ bwd_diff(sumSquare)(dpa, dpb, 1.0);
+ if (dpa.d == none)
+ {
+ outputBuffer[2] = 100.0f;
+ }
+} \ No newline at end of file
diff --git a/tests/initializer-list/existential-is-not-c-like.slang b/tests/initializer-list/existential-is-not-c-like.slang
new file mode 100644
index 000000000..1058033c9
--- /dev/null
+++ b/tests/initializer-list/existential-is-not-c-like.slang
@@ -0,0 +1,21 @@
+// Test that in Slang 2026, it is no longer valid to default initialize an existential value.
+#lang 2026
+
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+interface IBSDF
+{
+ float3 eval(float3 wi, float3 wo);
+}
+
+struct ShaderGraph
+{
+ IBSDF bsdf_stack[8]; // Intentionally uninitialized.
+ int next_bsdf = 0; // must be zero.
+}
+
+[numthreads(1,1,1)]
+void main()
+{
+ // CHECK: ([[# @LINE+1]]): error
+ ShaderGraph sg = {};
+} \ No newline at end of file
diff --git a/tests/language-feature/interfaces/optional-none.slang b/tests/language-feature/interfaces/optional-none.slang
new file mode 100644
index 000000000..04bb8f83e
--- /dev/null
+++ b/tests/language-feature/interfaces/optional-none.slang
@@ -0,0 +1,47 @@
+// Test that the size of an optional interface type is the same as the existential box.
+
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -conformance "Impl1:IFoo=1" -entry computeMain -profile cs_6_0
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUFFER): -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUFFER): -vk -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+interface IFoo
+{
+ int method();
+}
+
+//TEST_INPUT: type_conformance Impl1:IFoo = 0
+struct Impl1 : IFoo
+{
+ int data;
+ int method() { return data + 1; }
+}
+
+struct MyType
+{
+ Optional<IFoo> foo;
+}
+
+Optional<T> process<T>(Optional<T> opt)
+{
+ return opt;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ MyType t = {};
+ t.foo = process(t.foo);
+
+ // BUFFER: 100
+ if (let f = t.foo)
+ outputBuffer[0] = f.method();
+ else
+ outputBuffer[0] = 100;
+}
+
+// CHECK: struct MyType
+// CHECK-NEXT: {
+// CHECK-NEXT: Tuple{{.*}} foo{{.*}}; \ No newline at end of file
diff --git a/tests/language-feature/interfaces/zero-init-interface.slang b/tests/language-feature/interfaces/zero-init-interface.slang
deleted file mode 100644
index ed3b1eaa4..000000000
--- a/tests/language-feature/interfaces/zero-init-interface.slang
+++ /dev/null
@@ -1,33 +0,0 @@
-// Test that we can zero-init a struct with interface typed member.
-
-//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUFFER): -shaderobj
-//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUFFER): -vk -shaderobj
-
-//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
-RWStructuredBuffer<int> outputBuffer;
-
-interface IFoo
-{
- int method();
-}
-
-//TEST_INPUT: type_conformance Impl1:IFoo = 0
-struct Impl1 : IFoo
-{
- int data;
- int method() { return data + 1; }
-}
-
-struct MyType
-{
- IFoo foo;
-}
-
-
-[numthreads(1, 1, 1)]
-void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
-{
- MyType t = {};
- // BUFFER: 1
- outputBuffer[0] = t.foo.method();
-}
diff --git a/tests/language-feature/nested-optional.slang b/tests/language-feature/nested-optional.slang
new file mode 100644
index 000000000..8fdeac33b
--- /dev/null
+++ b/tests/language-feature/nested-optional.slang
@@ -0,0 +1,35 @@
+//TEST:INTERPRET(filecheck=CHECK):
+
+Optional<Optional<int>> getNone() { return none; }
+
+void main()
+{
+ Optional<Optional<Optional<int>>> val = Optional<Optional<int>>(5);
+ Optional<Optional<Optional<int>>> defaultVal1 = none;
+ Optional<Optional<Optional<int>>> defaultVal2 = getNone();
+
+ // CHECK: 8
+ printf("%d\n", sizeof(val));
+
+ // CHECK: success
+ if (defaultVal1.hasValue == defaultVal2.hasValue)
+ {
+ printf("success\n");
+ }
+ else
+ {
+ printf("failure\n");
+ }
+
+ // CHECK: value: 5
+ if (let x = val)
+ {
+ if (let y = x)
+ {
+ if (let z = y)
+ {
+ printf("value: %d\n", z);
+ }
+ }
+ }
+} \ No newline at end of file