diff options
Diffstat (limited to 'source')
34 files changed, 570 insertions, 86 deletions
diff --git a/source/core/slang-string-escape-util.cpp b/source/core/slang-string-escape-util.cpp index 0645d94ba..c079b8b39 100644 --- a/source/core/slang-string-escape-util.cpp +++ b/source/core/slang-string-escape-util.cpp @@ -1099,6 +1099,22 @@ StringEscapeUtil::Handler* StringEscapeUtil::getHandler(Style style) } } +/* static */ UnownedStringSlice StringEscapeUtil::maybeUnquoteCommandLineArg( + UnownedStringSlice slice) +{ + // If the slice is quoted, unquote it, else return as is + if (slice.startsWith("\'") || slice.startsWith("\"")) + { + const Index len = slice.getLength(); + if (len >= 2 && slice[len - 1] == slice[0]) + { + // Unquote it + return UnownedStringSlice(slice.begin() + 1, len - 2); + } + } + return slice; +} + /* static */ bool StringEscapeUtil::isQuoted(char quoteChar, UnownedStringSlice& slice) { const Index len = slice.getLength(); diff --git a/source/core/slang-string-escape-util.h b/source/core/slang-string-escape-util.h index ece8de79f..07b3bcc3d 100644 --- a/source/core/slang-string-escape-util.h +++ b/source/core/slang-string-escape-util.h @@ -79,6 +79,9 @@ struct StringEscapeUtil return isQuoted(handler->getQuoteChar(), slice); } + /// Given a command line arg slice, if it is quoted, unquotes it, else returns the slice as is. + static UnownedStringSlice maybeUnquoteCommandLineArg(UnownedStringSlice slice); + /// If quoting is needed appends to out quoted static SlangResult appendMaybeQuoted( Handler* handler, diff --git a/source/slang-record-replay/record/slang-session.cpp b/source/slang-record-replay/record/slang-session.cpp index d290afe0d..800d690fa 100644 --- a/source/slang-record-replay/record/slang-session.cpp +++ b/source/slang-record-replay/record/slang-session.cpp @@ -369,6 +369,22 @@ SLANG_NO_THROW SlangResult SessionRecorder::getTypeConformanceWitnessMangledName return result; } +SLANG_NO_THROW SlangResult SessionRecorder::getDynamicObjectRTTIBytes( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outRTTIDataBuffer, + uint32_t bufferSizeInBytes) +{ + // No need to record this function, it's just a query. + + SlangResult result = m_actualSession->getDynamicObjectRTTIBytes( + type, + interfaceType, + outRTTIDataBuffer, + bufferSizeInBytes); + return result; +} + SLANG_NO_THROW SlangResult SessionRecorder::getTypeConformanceWitnessSequentialID( slang::TypeReflection* type, slang::TypeReflection* interfaceType, diff --git a/source/slang-record-replay/record/slang-session.h b/source/slang-record-replay/record/slang-session.h index ea76d0dde..9cff7beac 100644 --- a/source/slang-record-replay/record/slang-session.h +++ b/source/slang-record-replay/record/slang-session.h @@ -69,6 +69,11 @@ public: slang::TypeReflection* type, slang::TypeReflection* interfaceType, uint32_t* outId) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outRTTIDataBuffer, + uint32_t bufferSizeInBytes) override; SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( slang::TypeReflection* type, slang::TypeReflection* interfaceType, diff --git a/source/slang-record-replay/util/record-format.h b/source/slang-record-replay/util/record-format.h index f1ae2e71b..99915c46f 100644 --- a/source/slang-record-replay/util/record-format.h +++ b/source/slang-record-replay/util/record-format.h @@ -112,7 +112,6 @@ enum ApiCallId : uint32_t ISession_getLoadedModule = makeApiCallId(Class_ISession, 0x0012), ISession_isBinaryModuleUpToDate = makeApiCallId(Class_ISession, 0x0013), - IModule_findEntryPointByName = makeApiCallId(Class_IModule, 0x0001), IModule_getDefinedEntryPointCount = makeApiCallId(Class_IModule, 0x0002), IModule_getDefinedEntryPoint = makeApiCallId(Class_IModule, 0x0003), diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 140c9ba16..484f51bfc 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1463,7 +1463,7 @@ typealias __Addr<T> = Ptr<T, $( (uint64_t)AddressSpace::Generic)ULL>; __generic<T> __magic_type(OptionalType) __intrinsic_type($(kIROp_OptionalType)) -struct Optional +struct Optional : IDefaultInitializable { /// Return `true` iff this `Optional` contains a value of type `T` property bool hasValue @@ -1482,6 +1482,9 @@ struct Optional __implicit_conversion($(kConversionCost_ValToOptional)) __intrinsic_op($(kIROp_MakeOptionalValue)) __init(T val); + + [__unsafeForceInlineEarly] + __init() { this = none; } }; //@hidden: diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c24a8b11a..13c5d2d47 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1396,6 +1396,45 @@ extension Array<T, N> : IDifferentiablePtrType typedef Array<T.Differential, N> Differential; } +__generic<T:IDifferentiable> +extension Optional<T> : IDifferentiable +{ + typedef Optional<T.Differential> Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Optional<T.Differential>(); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + if (!a.hasValue) + return b; + if (b.hasValue) + return T.dadd(a.value, b.value); + else + return a; + } + + __generic<U : __BuiltinRealType> + [__unsafeForceInlineEarly] + static Differential dmul(U a, Differential b) + { + if (b.hasValue) + return Optional<T.Differential>(T.dmul<U>(a, b.value)); + else + return b; + } +} + +__generic<T : IDifferentiablePtrType> +extension Optional<T> : IDifferentiablePtrType +{ + typedef Optional<T.Differential> Differential; +} + __generic<each T : IDifferentiable> extension Tuple<T> : IDifferentiable { diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp index 8bfc5f8ce..f15dee1d1 100644 --- a/source/slang/slang-ast-natural-layout.cpp +++ b/source/slang/slang-ast-natural-layout.cpp @@ -4,6 +4,7 @@ #include "slang-ast-builder.h" // For BaseInfo +#include "slang-check-impl.h" #include "slang-compiler.h" namespace Slang @@ -165,6 +166,15 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) return size; } + else if (auto optionalType = as<OptionalType>(type)) + { + if (isNullableType(optionalType->getValueType())) + return calcSize(optionalType->getValueType()); + NaturalSize size = NaturalSize::makeEmpty(); + size.append(calcSize(m_astBuilder->getBoolType())); + size.append(calcSize(optionalType->getValueType())); + return size; + } else if (auto declRefType = as<DeclRefType>(type)) { if (const auto enumDeclRef = declRefType->getDeclRef().as<EnumDecl>()) diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index fcbc83673..41355597f 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -250,6 +250,13 @@ bool SemanticsVisitor::isCStyleType(Type* type, HashSet<Type*>& isVisit) as<PtrType>(type)) return cacheResult(true); + // Slang 2026 language fix: an interface type is not C-style. + if (isSlang2026OrLater(this)) + { + // TODO: some/dyn types are also not C-style. + if (isDeclRefTypeOf<InterfaceDecl>(type)) + return cacheResult(false); + } // A tuple type is C-style if all of its members are C-style. if (auto tupleType = as<TupleType>(type)) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5aff41988..e3b05ec00 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -34,7 +34,7 @@ static bool isAssociatedTypeDecl(Decl* decl) return false; } -static bool isSlang2026OrLater(SemanticsVisitor* visitor) +bool isSlang2026OrLater(SemanticsVisitor* visitor) { return visitor->getShared()->m_module->getModuleDecl()->languageVersion >= SLANG_LANGUAGE_VERSION_2026; @@ -1604,6 +1604,23 @@ EnumDecl* isEnumType(Type* type) return nullptr; } +bool isNullableType(Type* type) +{ + if (as<PtrTypeBase>(type)) + return true; + if (isDeclRefTypeOf<InterfaceDecl>(type)) + return true; + if (isDeclRefTypeOf<ClassDecl>(type)) + return true; + if (as<OptionalType>(type)) + return true; + if (as<RefTypeBase>(type)) + return true; + if (as<NativeStringType>(type)) + return true; + return false; +} + bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) { if (state < DeclCheckState::DefinitionChecked) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 80436e68a..1cdebb115 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -42,6 +42,8 @@ bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl); bool isUniformParameterType(Type* type); +bool isSlang2026OrLater(SemanticsVisitor* visitor); + /// Create a new component type based on `inComponentType`, but with all its requiremetns filled. RefPtr<ComponentType> fillRequirements(ComponentType* inComponentType); @@ -3115,6 +3117,10 @@ bool isUnsizedArrayType(Type* type); bool isInterfaceType(Type* type); +// Check if `type` is nullable. An `Optional<T>` will occupy the same space as `T`, if `T` +// is nullable. +bool isNullableType(Type* type); + EnumDecl* isEnumType(Type* type); DeclVisibility getDeclVisibility(Decl* decl); diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 6a713f412..a57e01a88 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -915,6 +915,51 @@ RefPtr<ComponentType> fillRequirements(ComponentType* inComponentType) return componentType; } +bool parseTypeConformanceArgString( + UnownedStringSlice optionString, + UnownedStringSlice& outTypeName, + UnownedStringSlice& outInterfaceName, + Index& outSequentialId) +{ + // The expected format for the type conformance argument is: + // `TypeName:InterfaceName[=SequentialId]` + // + // Where `TypeName` is the name of a concrete type, `InterfaceName` + // is the name of an interface type, and `SequentialId` is an optional + // integer that specifies a sequential ID for the conformance. + // + // If the string does not match this format, we will return false. + + outTypeName = UnownedStringSlice(); + outInterfaceName = UnownedStringSlice(); + outSequentialId = -1; + auto colonPos = optionString.indexOf(':'); + if (colonPos < 0) + { + // If there is no colon, then the string is invalid. + return false; + } + outTypeName = optionString.head(colonPos); + auto interfaceNameStart = colonPos + 1; + auto equalsPos = optionString.indexOf('='); + if (equalsPos < interfaceNameStart) + { + // If there is no equals sign, then the interface name goes to the end of the string. + outInterfaceName = optionString.tail(interfaceNameStart); + } + else + { + // If there is an equals sign, then the interface name goes up to that point. + outInterfaceName = + optionString.subString(interfaceNameStart, equalsPos - interfaceNameStart); + // The sequential ID is the part after the equals sign. + auto sequentialIdString = optionString.tail(equalsPos + 1); + if (SLANG_FAILED(StringUtil::parseInt(sequentialIdString, outSequentialId))) + return false; + } + return true; +} + /// Create a component type to represent the "global scope" of a compile request. /// /// This component type will include all the modules and their global @@ -965,6 +1010,85 @@ RefPtr<ComponentType> createUnspecializedGlobalComponentType(FrontEndCompileRequ CompositeComponentType::create(linkage, translationUnitComponentTypes); } + List<RefPtr<ComponentType>> conformanceComponents; + + // Find and include all type conformances specified through compiler options. + for (auto conformances : + compileRequest->optionSet.getArray(CompilerOptionName::TypeConformance)) + { + auto stringValue = conformances.stringValue.getUnownedSlice(); + UnownedStringSlice typeName, interfaceName; + Index sequentialId = -1; + if (!parseTypeConformanceArgString(stringValue, typeName, interfaceName, sequentialId)) + { + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::invalidTypeConformanceOptionString, + stringValue); + continue; + } + auto concreteType = globalComponentType->getTypeFromString( + String(typeName).getBuffer(), + compileRequest->getSink()); + if (!concreteType) + { + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::invalidTypeConformanceOptionNoType, + stringValue, + typeName); + continue; + } + auto interfaceType = globalComponentType->getTypeFromString( + String(interfaceName).getBuffer(), + compileRequest->getSink()); + if (!interfaceType) + { + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::invalidTypeConformanceOptionNoType, + stringValue, + interfaceName); + continue; + } + ComPtr<slang::ITypeConformance> conformanceComponent; + ComPtr<ISlangBlob> diagnostics; + compileRequest->getLinkage()->createTypeConformanceComponentType( + (slang::TypeReflection*)concreteType, + (slang::TypeReflection*)interfaceType, + conformanceComponent.writeRef(), + sequentialId, + diagnostics.writeRef()); + if (!conformanceComponent) + { + // If we failed to create the conformance component, then + // we should report the diagnostics that were generated. + // + compileRequest->getSink()->diagnose( + SourceLoc(), + Diagnostics::cannotCreateTypeConformance, + stringValue); + if (diagnostics) + { + compileRequest->getSink()->diagnoseRaw( + Severity::Error, + UnownedStringSlice((char*)diagnostics->getBufferPointer())); + } + continue; + } + conformanceComponents.add(static_cast<TypeConformance*>(conformanceComponent.get())); + } + + if (conformanceComponents.getCount() > 0) + { + // If we found any type conformances, then we will + // create a composite component type that includes + // the global component type and the conformance components. + // + conformanceComponents.add(globalComponentType); + globalComponentType = CompositeComponentType::create(linkage, conformanceComponents); + } + return fillRequirements(globalComponentType); } diff --git a/source/slang/slang-compiler-options.cpp b/source/slang/slang-compiler-options.cpp index 5c17121cc..843e0e7cb 100644 --- a/source/slang/slang-compiler-options.cpp +++ b/source/slang/slang-compiler-options.cpp @@ -198,6 +198,7 @@ bool CompilerOptionSet::allowDuplicate(CompilerOptionName name) case CompilerOptionName::DownstreamArgs: case CompilerOptionName::VulkanBindShift: case CompilerOptionName::VulkanBindShiftAll: + case CompilerOptionName::TypeConformance: return true; } return false; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index e06278599..8f7f7d49a 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2192,6 +2192,11 @@ public: slang::TypeReflection* type, slang::TypeReflection* interfaceType, uint32_t* outId) override; + SLANG_NO_THROW SlangResult SLANG_MCALL getDynamicObjectRTTIBytes( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outBytes, + uint32_t bufferSize) override; SLANG_NO_THROW SlangResult SLANG_MCALL createTypeConformanceComponentType( slang::TypeReflection* type, slang::TypeReflection* interfaceType, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 4aadfd78d..465602a33 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -249,6 +249,17 @@ DIAGNOSTIC( cannotMatchOutputFileToEntryPoint, "the output path '$0' is not associated with any entry point; a '-o' option for a compiled " "kernel must follow the '-entry' option for its corresponding entry point") +DIAGNOSTIC( + 71, + Error, + invalidTypeConformanceOptionString, + "syntax error in type conformance option '$0'.") +DIAGNOSTIC( + 72, + Error, + invalidTypeConformanceOptionNoType, + "invalid conformance option '$0', type '$0' is not found.") +DIAGNOSTIC(73, Error, cannotCreateTypeConformance, "cannot create type conformance '$0'.") DIAGNOSTIC( 80, diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 26b964d6e..0534b64d4 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -5331,7 +5331,8 @@ void CLikeSourceEmitter::computeEmitActions(IRModule* module, List<EmitAction>& // Skip resource types in this pass. if (isResourceType(inst->getDataType())) continue; - + if (as<IRInterfaceRequirementEntry>(inst)) + continue; ensureGlobalInst(&ctx, inst, EmitAction::Level::Definition); } } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index a4362b912..20459c722 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -778,9 +778,6 @@ Result linkAndOptimizeIR( break; } - if (requiredLoweringPassSet.optionalType) - lowerOptionalType(irModule, sink); - switch (target) { case CodeGenTarget::CUDASource: @@ -792,20 +789,6 @@ Result linkAndOptimizeIR( break; } - switch (target) - { - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - { - lowerComInterfaces(irModule, artifactDesc.style, sink); - generateDllImportFuncs(codeGenContext->getTargetProgram(), irModule, sink); - generateDllExportFuncs(irModule, sink); - break; - } - default: - break; - } - #if 0 dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED"); #endif @@ -947,12 +930,6 @@ Result linkAndOptimizeIR( break; } - // Lower `Result<T,E>` types into ordinary struct types. This must happen - // after specialization, since otherwise incompatible copies of the lowered - // result structure are generated. - if (requiredLoweringPassSet.resultType) - lowerResultType(irModule, sink); - // Report checkpointing information if (codeGenContext->shouldReportCheckpointIntermediates()) { @@ -978,6 +955,29 @@ Result linkAndOptimizeIR( finalizeSpecialization(irModule); + // Lower `Result<T,E>` types into ordinary struct types. This must happen + // after specialization, since otherwise incompatible copies of the lowered + // result structure are generated. + if (requiredLoweringPassSet.resultType) + lowerResultType(irModule, sink); + + if (requiredLoweringPassSet.optionalType) + lowerOptionalType(irModule, sink); + + switch (target) + { + case CodeGenTarget::CPPSource: + case CodeGenTarget::HostCPPSource: + { + lowerComInterfaces(irModule, artifactDesc.style, sink); + generateDllImportFuncs(codeGenContext->getTargetProgram(), irModule, sink); + generateDllExportFuncs(irModule, sink); + break; + } + default: + break; + } + requiredLoweringPassSet = {}; calcRequiredLoweringPassSet(requiredLoweringPassSet, codeGenContext, irModule->getModuleInst()); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index df690a4e2..003790793 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1308,6 +1308,32 @@ InstPair ForwardDiffTranscriber::transcribeGetTupleElement(IRBuilder* builder, I return InstPair(primalGetElement, diffGetElement); } +InstPair ForwardDiffTranscriber::transcribeGetOptionalValue( + IRBuilder* builder, + IRInst* originalInst) +{ + IRInst* origBase = originalInst->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); + + IRInst* primalGetOptionalVal = + builder->emitIntrinsicInst(primalType, originalInst->getOp(), 1, &primalBase); + + IRInst* diffGetOptionalVal = nullptr; + + if (auto diffType = differentiateType(builder, primalGetOptionalVal->getDataType())) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + diffGetOptionalVal = + builder->emitIntrinsicInst(diffType, originalInst->getOp(), 1, &diffBase); + } + } + + return InstPair(primalGetOptionalVal, diffGetOptionalVal); +} + InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst) { auto updateInst = as<IRUpdateElement>(originalInst); @@ -2020,6 +2046,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeArray: case kIROp_MakeArrayFromElement: case kIROp_MakeTuple: + case kIROp_MakeOptionalValue: + case kIROp_MakeResultValue: case kIROp_MakeValuePack: case kIROp_BuiltinCast: return transcribeConstruct(builder, origInst); @@ -2063,6 +2091,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_GetTupleElement: return transcribeGetTupleElement(builder, origInst); + case kIROp_GetOptionalValue: + return transcribeGetOptionalValue(builder, origInst); case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); @@ -2197,6 +2227,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeCoopVectorFromValuePack: case kIROp_GetCurrentStage: case kIROp_GetOffsetPtr: + case kIROp_IsNullExistential: + case kIROp_MakeResultError: + case kIROp_IsResultError: + case kIROp_GetResultError: + case kIROp_MakeOptionalNone: + case kIROp_OptionalHasValue: return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 09b3f14b8..1bdbb01c8 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -72,6 +72,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeGetTupleElement(IRBuilder* builder, IRInst* origInst); + InstPair transcribeGetOptionalValue(IRBuilder* builder, IRInst* originalInst); + InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst); InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index d67d75997..d3d5d72a9 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -390,6 +390,13 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy diffTypeList.getBuffer()); } + case kIROp_OptionalType: + { + auto origOptionalType = as<IROptionalType>(primalType); + auto diffValueType = differentiateType(builder, origOptionalType->getValueType()); + return builder->getOptionalType(diffValueType); + } + default: return (IRType*)maybeCloneForPrimalInst( builder, diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 606428efe..09f70725a 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1476,6 +1476,9 @@ struct DiffTransposePass case kIROp_GetElement: return transposeGetElement(builder, as<IRGetElement>(fwdInst), revValue); + case kIROp_GetOptionalValue: + return transposeGetOptionalValue(builder, as<IRGetOptionalValue>(fwdInst), revValue); + case kIROp_Return: return transposeReturn(builder, as<IRReturn>(fwdInst), revValue); @@ -1531,7 +1534,8 @@ struct DiffTransposePass return transposeMakeTuple(builder, fwdInst, revValue); case kIROp_MakeArrayFromElement: return transposeMakeArrayFromElement(builder, fwdInst, revValue); - + case kIROp_MakeOptionalValue: + return transposeMakeOptionalValue(builder, fwdInst, revValue); case kIROp_UpdateElement: return transposeUpdateElement(builder, fwdInst, revValue); @@ -1673,6 +1677,20 @@ struct DiffTransposePass fwdGetElement))); } + TranspositionResult transposeGetOptionalValue( + IRBuilder* builder, + IRGetOptionalValue* fwdGetOptionalValue, + IRInst* revValue) + { + // dP = GetOptionalValue(dVal) -> dVal = MakeOptionalValue(dP) + auto optionalVal = fwdGetOptionalValue->getOperand(0); + return TranspositionResult(List<RevGradient>(RevGradient( + RevGradient::Flavor::Simple, + fwdGetOptionalValue->getOperand(0), + builder->emitMakeOptionalValue(optionalVal->getDataType(), revValue), + fwdGetOptionalValue))); + } + TranspositionResult transposeMakePair( IRBuilder*, IRMakeDifferentialPair* fwdMakePair, @@ -1982,6 +2000,29 @@ struct DiffTransposePass return TranspositionResult(gradients); } + TranspositionResult transposeMakeOptionalValue( + IRBuilder* builder, + IRInst* fwdMakeOptionalValue, + IRInst* revValue) + { + List<RevGradient> gradients; + + auto gradAtField = builder->emitGetOptionalValue(revValue); + auto diffZero = diffTypeContext.emitDZeroOfDiffInstType( + builder, + tryGetPrimalTypeFromDiffInst(fwdMakeOptionalValue->getOperand(0))); + IRInst* selectArgs[] = {builder->emitOptionalHasValue(revValue), gradAtField, diffZero}; + builder->emitIntrinsicInst(gradAtField->getDataType(), kIROp_Select, 3, selectArgs); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeOptionalValue->getOperand(0), + gradAtField, + fwdMakeOptionalValue)); + + // (A = MakeOptionalValue(F)) -> [(dF += dA.hasValue?dA.value:dzero)] + return TranspositionResult(gradients); + } + TranspositionResult transposeMakeStruct( IRBuilder* builder, IRInst* fwdMakeStruct, diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index c8dc3b480..133c257a8 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1754,7 +1754,19 @@ IRType* DifferentiableTypeConformanceContext::differentiateType( (UInt)diffTypeList.getCount(), diffTypeList.getBuffer()); } - + case kIROp_OptionalType: + { + auto primalOptionalType = as<IROptionalType>(primalType); + if (auto diffElementType = + differentiateType(builder, primalOptionalType->getValueType())) + { + return builder->getOptionalType(diffElementType); + } + else + { + return nullptr; + } + } default: return (IRType*)getDifferentialForType(builder, (IRType*)primalType); } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index befd1f98a..970f490c9 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -320,6 +320,7 @@ struct DifferentiableTypeConformanceContext } case kIROp_TupleType: case kIROp_TypePack: + case kIROp_OptionalType: { return differentiateType(builder, origType); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 3b45d46b3..7a281bac4 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1225,7 +1225,7 @@ INST(GetValueFromBoundInterface, getValueFromBoundInterface, 1, 0) INST(ExtractExistentialValue, extractExistentialValue, 1, 0) INST(ExtractExistentialType, extractExistentialType, 1, HOISTABLE) INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOISTABLE) - +INST(IsNullExistential, isNullExistential, 1, 0) INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2fff4e451..b5c1a6475 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3343,6 +3343,12 @@ struct IRExtractExistentialWitnessTable : IRInst IR_LEAF_ISA(ExtractExistentialWitnessTable); }; +struct IRIsNullExistential : IRInst +{ + IR_LEAF_ISA(IsNullExistential); +}; + + /* Base class for instructions that track liveness */ struct IRLiveRangeMarker : IRInst { @@ -4059,6 +4065,9 @@ public: /// Given an existential value, extract the underlying "real" type IRType* emitExtractExistentialType(IRInst* existentialValue); + /// Given an existential value, return if it is empty/null. + IRInst* emitIsNullExistential(IRInst* existentialValue); + /// Given an existential value, extract the witness table showing how the value conforms to the /// existential type. IRInst* emitExtractExistentialWitnessTable(IRInst* existentialValue); diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp index c4ee6e6fc..076ed57bd 100644 --- a/source/slang/slang-ir-lower-existential.cpp +++ b/source/slang/slang-ir-lower-existential.cpp @@ -131,6 +131,19 @@ struct ExistentialLoweringContext processExtractExistentialElement(inst, 2); } + void processIsNullExistential(IRIsNullExistential* inst) + { + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(inst); + + auto rttiElement = extractTupleElement(&builder, inst->getOperand(0), 0); + auto isNull = builder.emitNeq( + builder.emitGetElement(builder.getUIntType(), rttiElement, 0), + builder.getIntValue(builder.getUIntType(), 0)); + inst->replaceUsesWith(isNull); + inst->removeAndDeallocate(); + } + void processExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) { processExtractExistentialElement(inst, 1); @@ -261,6 +274,10 @@ struct ExistentialLoweringContext { processExtractExistentialWitnessTable(extractExistentialWitnessTable); } + else if (auto isNullExistential = as<IRIsNullExistential>(inst)) + { + processIsNullExistential(isNullExistential); + } } void processModule() diff --git a/source/slang/slang-ir-lower-optional-type.cpp b/source/slang/slang-ir-lower-optional-type.cpp index 5c9dcd722..1f9f398d3 100644 --- a/source/slang/slang-ir-lower-optional-type.cpp +++ b/source/slang/slang-ir-lower-optional-type.cpp @@ -8,6 +8,13 @@ namespace Slang { +enum LoweredOptionalTypeKind +{ + Struct, + PtrValue, + ExistentialValue, +}; + struct OptionalTypeLoweringContext { IRModule* module; @@ -16,10 +23,6 @@ struct OptionalTypeLoweringContext InstWorkList workList; InstHashSet workListSet; - IRGeneric* genericOptionalStructType = nullptr; - IRStructKey* valueKey = nullptr; - IRStructKey* hasValueKey = nullptr; - OptionalTypeLoweringContext(IRModule* inModule) : module(inModule), workList(inModule), workListSet(inModule) { @@ -30,6 +33,9 @@ struct OptionalTypeLoweringContext IRType* optionalType = nullptr; IRType* valueType = nullptr; IRType* loweredType = nullptr; + IRStructKey* hasValueKey = nullptr; + IRStructKey* valueKey = nullptr; + LoweredOptionalTypeKind kind = LoweredOptionalTypeKind::Struct; }; Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo; Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes; @@ -42,37 +48,29 @@ struct OptionalTypeLoweringContext return type; } - IRInst* getOrCreateGenericOptionalStruct() + IRInst* createOptionalStruct(IRType* type, LoweredOptionalTypeInfo* info) { - if (genericOptionalStructType) - return genericOptionalStructType; IRBuilder builder(module); builder.setInsertInto(module->getModuleInst()); - valueKey = builder.createStructKey(); - builder.addNameHintDecoration(valueKey, UnownedStringSlice("value")); - hasValueKey = builder.createStructKey(); - builder.addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue")); - - genericOptionalStructType = builder.emitGeneric(); - builder.addNameHintDecoration( - genericOptionalStructType, - UnownedStringSlice("_slang_Optional")); + info->valueKey = builder.createStructKey(); + builder.addNameHintDecoration(info->valueKey, UnownedStringSlice("value")); + info->hasValueKey = builder.createStructKey(); + builder.addNameHintDecoration(info->hasValueKey, UnownedStringSlice("hasValue")); - builder.setInsertInto(genericOptionalStructType); - auto block = builder.emitBlock(); - auto typeParam = builder.emitParam(builder.getTypeKind()); auto structType = builder.createStructType(); - builder.addNameHintDecoration(structType, UnownedStringSlice("_slang_Optional")); - builder.createStructField(structType, valueKey, (IRType*)typeParam); - builder.createStructField(structType, hasValueKey, builder.getBoolType()); - builder.setInsertInto(block); - builder.emitReturn(structType); - genericOptionalStructType->setFullType(builder.getTypeKind()); - return genericOptionalStructType; + StringBuilder sb; + sb << "_slang_Optional_"; + getTypeNameHint(sb, type); + builder.addNameHintDecoration(structType, sb.getUnownedSlice()); + builder.createStructField(structType, info->valueKey, type); + builder.createStructField(structType, info->hasValueKey, builder.getBoolType()); + + info->kind = LoweredOptionalTypeKind::Struct; + return structType; } - bool typeHasNullValue(IRInst* type) + bool typeHasNullValue(IRInst* type, LoweredOptionalTypeKind& outKind) { switch (type->getOp()) { @@ -81,21 +79,25 @@ struct OptionalTypeLoweringContext case kIROp_NativeStringType: case kIROp_PtrType: case kIROp_ClassType: + outKind = LoweredOptionalTypeKind::PtrValue; return true; case kIROp_InterfaceType: - return isComInterfaceType((IRType*)type); + if (isComInterfaceType((IRType*)type)) + outKind = LoweredOptionalTypeKind::PtrValue; + else + outKind = LoweredOptionalTypeKind::ExistentialValue; + return true; default: return false; } } - LoweredOptionalTypeInfo* getLoweredOptionalType(IRBuilder* builder, IRInst* type) + LoweredOptionalTypeInfo* getLoweredOptionalType(IRBuilder*, IRInst* type) { if (auto loweredInfo = loweredOptionalTypes.tryGetValue(type)) return loweredInfo->Ptr(); if (auto loweredInfo = mapLoweredTypeToOptionalTypeInfo.tryGetValue(type)) return loweredInfo->Ptr(); - if (!type) return nullptr; if (type->getOp() != kIROp_OptionalType) @@ -104,19 +106,21 @@ struct OptionalTypeLoweringContext RefPtr<LoweredOptionalTypeInfo> info = new LoweredOptionalTypeInfo(); auto optionalType = cast<IROptionalType>(type); auto valueType = optionalType->getValueType(); + while (auto valueOptionalType = as<IROptionalType>(valueType)) + { + // If the value type is also an Optional, we need to keep lowering it. + valueType = valueOptionalType->getValueType(); + } + info->optionalType = (IRType*)type; info->valueType = valueType; - if (typeHasNullValue(valueType)) + if (typeHasNullValue(valueType, info->kind)) { info->loweredType = valueType; } else { - auto genericType = getOrCreateGenericOptionalStruct(); - IRInst* args[] = {valueType}; - auto specializedType = - builder->emitSpecializeInst(builder->getTypeKind(), genericType, 1, args); - info->loweredType = (IRType*)specializedType; + info->loweredType = (IRType*)createOptionalStruct(valueType, info); } mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info; loweredOptionalTypes[type] = info; @@ -171,6 +175,12 @@ struct OptionalTypeLoweringContext inst->replaceUsesWith(makeStruct); inst->removeAndDeallocate(); } + else if (info->kind == LoweredOptionalTypeKind::ExistentialValue) + { + auto zero = builder->emitDefaultConstruct(info->loweredType); + inst->replaceUsesWith(zero); + inst->removeAndDeallocate(); + } else { inst->replaceUsesWith(builder->getNullPtrValue(info->valueType)); @@ -183,13 +193,20 @@ struct OptionalTypeLoweringContext auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, optionalInst->getDataType()); SLANG_ASSERT(loweredOptionalTypeInfo); IRInst* result = nullptr; - if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType) - { - result = builder->emitFieldExtract(builder->getBoolType(), optionalInst, hasValueKey); - } - else + switch (loweredOptionalTypeInfo->kind) { + case LoweredOptionalTypeKind::Struct: + result = builder->emitFieldExtract( + builder->getBoolType(), + optionalInst, + loweredOptionalTypeInfo->hasValueKey); + break; + case LoweredOptionalTypeKind::PtrValue: result = builder->emitCastPtrToBool(optionalInst); + break; + case LoweredOptionalTypeKind::ExistentialValue: + result = builder->emitIsNullExistential(optionalInst); + break; } return result; } @@ -214,11 +231,13 @@ struct OptionalTypeLoweringContext auto base = inst->getOptionalOperand(); auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, base->getDataType()); - if (loweredOptionalTypeInfo->loweredType != loweredOptionalTypeInfo->valueType) + if (loweredOptionalTypeInfo->kind == LoweredOptionalTypeKind::Struct) { SLANG_ASSERT(loweredOptionalTypeInfo); - auto getElement = - builder->emitFieldExtract(loweredOptionalTypeInfo->valueType, base, valueKey); + auto getElement = builder->emitFieldExtract( + loweredOptionalTypeInfo->valueType, + base, + loweredOptionalTypeInfo->valueKey); inst->replaceUsesWith(getElement); } else diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp index 1d3f04318..f19bf2168 100644 --- a/source/slang/slang-ir-marshal-native-call.cpp +++ b/source/slang/slang-ir-marshal-native-call.cpp @@ -17,6 +17,17 @@ IRType* NativeCallMarshallingContext::getNativeType(IRBuilder& builder, IRType* return builder.getNativePtrType(type); case kIROp_ComPtrType: return builder.getNativePtrType((IRType*)as<IRComPtrType>(type)->getOperand(0)); + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + { + auto arrayType = as<IRArrayType>(type); + auto elementType = arrayType->getElementType(); + auto nativeElementType = getNativeType(builder, elementType); + return builder.getArrayTypeBase( + elementType->getOp(), + nativeElementType, + arrayType->getElementCount()); + } case kIROp_InOutType: case kIROp_RefType: case kIROp_ConstRefType: diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 28e98fdb6..7f3ff1d68 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -802,13 +802,33 @@ struct PeepholeContext : InstPassBase { if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue) { - IRBuilder builder(module); - IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); - builder.setInsertBefore(inst); - auto trueVal = builder.getBoolValue(true); - inst->replaceUsesWith(trueVal); - maybeRemoveOldInst(inst); - changed = true; + auto getHasValue = as<IROptionalHasValue>(inst); + auto optionalType = + as<IROptionalType>(getHasValue->getOptionalOperand()->getDataType()); + if (!optionalType) + break; + if (as<IROptionalType>(optionalType->getValueType())) + { + // HasValue(o : Optional<Optional<T>>) ==> HasValue(o.value). + IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); + auto newVal = builder.emitOptionalHasValue( + builder.emitGetOptionalValue(getHasValue->getOptionalOperand())); + inst->replaceUsesWith(newVal); + maybeRemoveOldInst(inst); + changed = true; + } + else + { + IRBuilder builder(module); + IRBuilderSourceLocRAII srcLocRAII(&builder, inst->sourceLoc); + builder.setInsertBefore(inst); + auto trueVal = builder.getBoolValue(true); + inst->replaceUsesWith(trueVal); + maybeRemoveOldInst(inst); + changed = true; + } } else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f571ec20b..e66ad69ce 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3515,6 +3515,14 @@ IRInst* IRBuilder::emitExtractExistentialValue(IRType* type, IRInst* existential return inst; } +IRInst* IRBuilder::emitIsNullExistential(IRInst* existentialValue) +{ + auto inst = + createInst<IRInst>(this, kIROp_IsNullExistential, getBoolType(), 1, &existentialValue); + addInst(inst); + return inst; +} + IRType* IRBuilder::emitExtractExistentialType(IRInst* existentialValue) { auto type = getTypeKind(); @@ -8648,6 +8656,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialWitnessTable: + case kIROp_IsNullExistential: case kIROp_WrapExistential: case kIROp_BuiltinCast: case kIROp_BitCast: diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 3ce107135..c269b10c4 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -530,6 +530,10 @@ void initCommandOptions(CommandOptions& options) nullptr, "Preserve all resource parameters in the output code, even if they are not used by the " "shader."}, + {OptionKind::TypeConformance, + "-conformance", + "-conformance <typeName>:<interfaceName>[=<sequentialID>]", + "Include additional type conformance during linking for dynamic dispatch."}, {OptionKind::EmitReflectionJSON, "-reflection-json", "reflection-json <path>", @@ -2736,6 +2740,17 @@ SlangResult OptionsParser::_parse(int argc, char const* const* argv) m_compileRequest->addSearchPath(String(slice).getBuffer()); break; } + case OptionKind::TypeConformance: + { + if (!m_reader.hasArg()) + break; + CommandLineArg operand; + SLANG_RETURN_ON_FAIL(m_reader.expectArg(operand)); + auto unquoted = + StringEscapeUtil::maybeUnquoteCommandLineArg(operand.value.getUnownedSlice()); + linkage->m_optionSet.add(OptionKind::TypeConformance, unquoted); + break; + } case OptionKind::Output: { // diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 398aab517..431cf6669 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6124,11 +6124,6 @@ Stmt* Parser::parseIfLetStatement() tempVarDecl->nameAndLoc = NameLoc(getName(this, "$OptVar"), identifierToken.loc); tempVarDecl->initExpr = initExpr; AddMember(currentScope->containerDecl, tempVarDecl); - if (semanticsVisitor) - semanticsVisitor->ensureDecl( - (Decl*)tempVarDecl, - DeclCheckState::DefinitionChecked, - nullptr); DeclStmt* tmpVarDeclStmt = astBuilder->create<DeclStmt>(); FillPosition(tmpVarDeclStmt); diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 2c72b61de..b2171823b 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -4946,6 +4946,8 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type else if (auto optionalType = as<OptionalType>(type)) { // OptionalType should be laid out the same way as Tuple<T, bool>. + if (isNullableType(optionalType->getValueType())) + return _createTypeLayout(context, optionalType->getValueType()); Array<Type*, 2> types = makeArray(optionalType->getValueType(), context.astBuilder->getBoolType()); auto tupleType = context.astBuilder->getTupleType(types.getView()); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 9aa8c56a7..065b9de93 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2029,6 +2029,31 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getTypeConformanceWitnessSequent return SLANG_OK; } +SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::getDynamicObjectRTTIBytes( + slang::TypeReflection* type, + slang::TypeReflection* interfaceType, + uint32_t* outBuffer, + uint32_t bufferSize) +{ + // Slang RTTI header format: + // byte 0-7: pointer to RTTI struct describing the type. (not used for now, set to 1 for valid + // types, and 0 to represent null). + // byte 8-11: 32-bit sequential ID of the type conformance witness. + // byte 12-15: unused. + + if (bufferSize < 16) + return SLANG_E_BUFFER_TOO_SMALL; + + SLANG_AST_BUILDER_RAII(getASTBuilder()); + + SLANG_RETURN_ON_FAIL(getTypeConformanceWitnessSequentialID(type, interfaceType, outBuffer + 2)); + + // Make the RTTI part non zero. + outBuffer[0] = 1; + + return SLANG_OK; +} + SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentType( slang::TypeReflection* type, slang::TypeReflection* interfaceType, |
