diff options
Diffstat (limited to 'source/slang')
23 files changed, 487 insertions, 25 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 37af3ef5a..03cdc9ee2 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -4298,16 +4298,25 @@ matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr); __generic<T : __BuiltinType> __target_intrinsic(hlsl, "WaveMatch($1).x") +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupPartitionNV($1).x") __cuda_sm_version(7.0) __target_intrinsic(cuda, "_waveMatchScalar($0, $1).x") WaveMask WaveMaskMatch(WaveMask mask, T value); __generic<T : __BuiltinType, let N : int> __target_intrinsic(hlsl, "WaveMatch($1).x") +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupPartitionNV($1).x") __cuda_sm_version(7.0) __target_intrinsic(cuda, "_waveMatchMultiple($0, $1)") WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value); __generic<T : __BuiltinType, let N : int, let M : int> __target_intrinsic(hlsl, "WaveMatch($1).x") +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupPartitionNV($1).x") __cuda_sm_version(7.0) __target_intrinsic(cuda, "_waveMatchMultiple($0, $1)") WaveMask WaveMaskMatch(WaveMask mask, matrix<T,N,M> value); diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 3ab2de3d6..56861949f 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -428,6 +428,17 @@ Val* ASTBuilder::getNoDiffModifierVal() return getOrCreate<NoDiffModifierVal>(); } +Type* ASTBuilder::getFuncType(List<Type*> parameters, Type* result) +{ + auto errorType = getOrCreate<BottomType>(); + return getOrCreate<FuncType>(parameters, result, errorType); +} + +Type* ASTBuilder::getTupleType(List<Type*>& types) +{ + return getOrCreate<TupleType>(types); +} + TypeType* ASTBuilder::getTypeType(Type* type) { return getOrCreate<TypeType>(type); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index c39293914..1ef56894b 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -151,6 +151,25 @@ public: /// no need for additional state. Dictionary<NodeDesc, NodeBase*> m_cachedNodes; + template<int N> + static void addOrAppendToNodeList(ShortList<NodeOperand, N>&) + {} + + template<int N, typename T, typename... Ts> + static void addOrAppendToNodeList(ShortList<NodeOperand, N>& list, T t, Ts... ts) + { + list.add(t); + addOrAppendToNodeList(list, ts...); + } + + template<int N, typename T, typename... Ts> + static void addOrAppendToNodeList(ShortList<NodeOperand, N>& list, const List<T>& l, Ts... ts ) + { + for(auto t : l) + list.add(t); + addOrAppendToNodeList(list, ts...); + } + public: // For compile time check to see if thing being constructed is an AST type @@ -174,11 +193,11 @@ public: } template<typename T, typename... TArgs> - T* create(TArgs... args) + T* create(TArgs&&... args) { auto alloced = m_arena.allocate(sizeof(T)); memset(alloced, 0, sizeof(T)); - return _initAndAdd(new (alloced) T(args...)); + return _initAndAdd(new (alloced) T(std::forward<TArgs>(args)...)); } template<typename T, typename ... TArgs> @@ -187,7 +206,7 @@ public: SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); NodeDesc desc; desc.type = T::kType; - addToList(desc.operands, args...); + addOrAppendToNodeList(desc.operands, args...); return (T*)_getOrCreateImpl(desc, [&]() { return create<T>(args...); @@ -210,7 +229,7 @@ public: SLANG_COMPILE_TIME_ASSERT(IsValidType<T>::Value); NodeDesc desc; desc.type = T::kType; - addToList(desc.operands, args...); + addOrAppendToNodeList(desc.operands, args...); return (T*)_getOrCreateImpl(desc, [&]() { return create<T>(); @@ -367,6 +386,10 @@ public: Val* getSNormModifierVal(); Val* getNoDiffModifierVal(); + Type* getTupleType(List<Type*>& types); + + Type* getFuncType(List<Type*> parameters, Type* result); + TypeType* getTypeType(Type* type); /// Helpers to get type info from the SharedASTBuilder diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 0a875fb50..da213e8d4 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -545,6 +545,22 @@ class PointerTypeExpr : public Expr TypeExp base; }; + /// A type expression that represents a function type, e.g. (bool, int) -> float +class FuncTypeExpr : public Expr +{ + SLANG_AST_CLASS(FuncTypeExpr); + + List<TypeExp> parameters; + TypeExp result; +}; + +class TupleTypeExpr : public Expr +{ + SLANG_AST_CLASS(TupleTypeExpr); + + List<TypeExp> members; +}; + /// An expression that applies a generic to arguments for some, /// but not all, of its explicit parameters. /// diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index b23f7c6ca..ea3db6937 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -236,6 +236,19 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->base.exp); } + void visitFuncTypeExpr(FuncTypeExpr* expr) + { + iterator->maybeDispatchCallback(expr); + for(const auto& t : expr->parameters) + dispatchIfNotNull(t.exp); + dispatchIfNotNull(expr->result.exp); + } + void visitTupleTypeExpr(TupleTypeExpr* expr) + { + iterator->maybeDispatchCallback(expr); + for(auto t : expr->members) + dispatchIfNotNull(t.exp); + } void visitPointerTypeExpr(PointerTypeExpr* expr) { iterator->maybeDispatchCallback(expr); diff --git a/source/slang/slang-ast-reflect.h b/source/slang/slang-ast-reflect.h index 61711b940..9bce74587 100644 --- a/source/slang/slang-ast-reflect.h +++ b/source/slang/slang-ast-reflect.h @@ -13,7 +13,7 @@ NAME() = default; \ public: \ typedef NAME This; \ - static const ASTNodeType kType = ASTNodeType::NAME; \ + static constexpr ASTNodeType kType = ASTNodeType::NAME; \ static const ReflectClassInfo kReflectClassInfo; \ SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) { return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); } \ SLANG_CLASS_REFLECT_SUPER_##TYPE(SUPER) \ diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 27cc7800b..90ed1b8e4 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -521,8 +521,8 @@ ParameterDirection FuncType::getParamDirection(Index index) void FuncType::_toTextOverride(StringBuilder& out) { - out << toSlice("("); Index paramCount = getParamCount(); + out << toSlice("("); for (Index pp = 0; pp < paramCount; ++pp) { if (pp != 0) @@ -531,7 +531,7 @@ void FuncType::_toTextOverride(StringBuilder& out) } out << getParamType(pp); } - out << toSlice(") -> ") << getResultType(); + out << ") -> " << getResultType(); if (!getErrorType()->equals(getASTBuilder()->getBottomType())) { @@ -634,6 +634,77 @@ HashCode FuncType::_getHashCodeOverride() return hashCode; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TupleType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void TupleType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("("); + for (Index pp = 0; pp < memberTypes.getCount(); ++pp) + { + if (pp != 0) + out << toSlice(", "); + out << memberTypes[pp]; + } + out << toSlice(")"); +} + +bool TupleType::_equalsImplOverride(Type * type) +{ + if (const auto other = as<TupleType>(type)) + { + auto paramCount = memberTypes.getCount(); + auto otherParamCount = other->memberTypes.getCount(); + if (paramCount != otherParamCount) + return false; + + for (Index i = 0; i < memberTypes.getCount(); ++i) + { + if(!memberTypes[i]->equals(other->memberTypes[i])) + return false; + } + + return true; + } + return false; +} + +Val* TupleType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + // just recurse into the members + List<Type*> substMemberTypes; + for (auto m : memberTypes) + substMemberTypes.add(as<Type>(m->substituteImpl(astBuilder, subst, &diff))); + + // early exit for no change... + if (!diff) + return this; + + (*ioDiff)++; + return astBuilder->create<TupleType>(std::move(substMemberTypes)); +} + +Type* TupleType::_createCanonicalTypeOverride() +{ + // member types + List<Type*> canMemberTypes; + for (auto m : memberTypes) + { + canMemberTypes.add(m->getCanonicalType()); + } + + return getASTBuilder()->create<TupleType>(std::move(canMemberTypes)); +} + +HashCode TupleType::_getHashCodeOverride() +{ + HashCode hashCode = Slang::getHashCode(kType); + for(auto m : memberTypes) + hashCode = combineHash(hashCode, m->getHashCode()); + return hashCode; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExtractExistentialType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 67288a59d..7da7ecc88 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -664,6 +664,15 @@ class FuncType : public Type { SLANG_AST_CLASS(FuncType) + // Construct a unary function + FuncType(Type* paramType, Type* resultType, Type* errorType) + : paramTypes{{paramType}}, resultType{resultType}, errorType{errorType} + {} + + FuncType(List<Type*> parameters, Type* result, Type* error) + : paramTypes(std::move(parameters)), resultType(result), errorType(error) + {} + // TODO: We may want to preserve parameter names // in the list here, just so that we can print // out friendly names when printing a function @@ -689,6 +698,29 @@ class FuncType : public Type HashCode _getHashCodeOverride(); }; +// A tuple is a product of its member types +class TupleType : public Type +{ + SLANG_AST_CLASS(TupleType) + + // Construct a unary tupletion + TupleType(List<Type*> memberTypes) + : memberTypes(std::move(memberTypes)) + {} + + auto getMemberCount() { return memberTypes.getCount(); } const + auto& getMember(Index i) { return memberTypes[i]; } + + List<Type*> memberTypes; + + // Overrides should be public so base classes can access + void _toTextOverride(StringBuilder& out); + Type* _createCanonicalTypeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); +}; + // The "type" of an expression that names a generic declaration. class GenericDeclRefType : public Type { diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 12988ebcf..cdffcf004 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -741,6 +741,20 @@ namespace Slang return true; } + } else if(auto fstFunType = as<FuncType>(fst)) + { + if (auto sndFunType = as<FuncType>(snd)) + { + const Index numParams = fstFunType->paramTypes.getCount(); + if(numParams != sndFunType->paramTypes.getCount()) + return false; + for(Index i = 0; i < numParams; ++i) + { + if(!TryUnifyTypes(constraints, fstFunType->paramTypes[i], sndFunType->paramTypes[i])) + return false; + } + return TryUnifyTypes(constraints, fstFunType->resultType, sndFunType->resultType); + } } return false; diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index d1231c72a..639ae7939 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -951,13 +951,13 @@ namespace Slang overloadContext.disallowNestedConversions = true; overloadContext.argCount = 1; overloadContext.argTypes = &fromType; + overloadContext.args = &fromExpr; overloadContext.originalExpr = nullptr; if(fromExpr) { overloadContext.loc = fromExpr->loc; overloadContext.funcLoc = fromExpr->loc; - overloadContext.args = &fromExpr; } overloadContext.baseExpr = nullptr; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 7b8be7a33..a91ec1e98 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -3677,4 +3677,43 @@ namespace Slang } } + Expr* SemanticsExprVisitor::visitFuncTypeExpr(FuncTypeExpr* expr) + { + // The input and output to a function type must both be types + for(auto& t : expr->parameters) + t = CheckProperType(t); + expr->result = CheckProperType(expr->result); + + // TODO: Kind checking? Where are we stopping someone passing + // constraints around as value-inhabitable types + + // The result of this expression is a `FuncType`, which we need + // to wrap in a `TypeType` to indicate that the result is the type + // itself and not a value of that type. + List<Type*> types; + types.reserve(expr->parameters.getCount()); + for(const auto& t : expr->parameters) + types.add(t.type); + auto funcType = m_astBuilder->getFuncType(std::move(types), expr->result.type); + expr->type = m_astBuilder->getTypeType(funcType); + + return expr; + } + + Expr* SemanticsExprVisitor::visitTupleTypeExpr(TupleTypeExpr* expr) + { + // All tuple members must be types + for(auto& t : expr->members) + t = CheckProperType(t); + + // As in the other cases above, wrap in TypeType + List<Type*> types; + types.reserve(expr->members.getCount()); + for(auto t : expr->members) + types.add(t.type); + auto tupleType = m_astBuilder->getTupleType(types); + expr->type = m_astBuilder->getTypeType(tupleType); + + return expr; + } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 2a78fa999..efc3989e2 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1528,7 +1528,7 @@ namespace Slang SourceLoc loc; // The original expression (if any) that triggered things - Expr* originalExpr = nullptr; + AppExprBase* originalExpr = nullptr; // Source location of the "function" part of the expression, if any SourceLoc funcLoc; @@ -1695,6 +1695,11 @@ namespace Slang FuncType* /*funcType*/, OverloadResolveContext& /*context*/); + void AddFuncExprOverloadCandidate( + FuncType* funcType, + OverloadResolveContext& context, + Expr* expr); + // Add a candidate callee for overload resolution, based on // calling a particular `ConstructorDecl`. void AddCtorOverloadCandidate( @@ -1966,6 +1971,8 @@ namespace Slang Expr* visitAndTypeExpr(AndTypeExpr* expr); Expr* visitPointerTypeExpr(PointerTypeExpr* expr); Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); + Expr* visitFuncTypeExpr(FuncTypeExpr* expr); + Expr* visitTupleTypeExpr(TupleTypeExpr* expr); Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index f2947a55d..5160f3c6f 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -412,6 +412,8 @@ namespace Slang if (context.mode == OverloadResolveContext::Mode::JustTrying) { + SLANG_ASSERT(argType); + ConversionCost cost = kConversionCost_None; if( context.disallowNestedConversions ) { @@ -1138,6 +1140,21 @@ namespace Slang AddOverloadCandidate(context, candidate); } + void SemanticsVisitor::AddFuncExprOverloadCandidate( + FuncType* funcType, + OverloadResolveContext& context, + Expr* expr) + { + SLANG_ASSERT(expr); + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = funcType; + candidate.resultType = funcType->getResultType(); + candidate.exprVal = expr; + + AddOverloadCandidate(context, candidate); + } + void SemanticsVisitor::AddCtorOverloadCandidate( LookupResultItem typeItem, Type* type, @@ -1432,9 +1449,23 @@ namespace Slang auto type = DeclRefType::create(m_astBuilder, genericTypeParamDeclRef); AddTypeOverloadCandidates(type, context); } + else if( auto localDeclRef = item.declRef.as<ParamDecl>() ) + { + // We could probably be broader than just parameters here + // eventually. + // Limit it for now though to make the specialization easier + ensureDecl(localDeclRef, DeclCheckState::CanUseFuncSignature); + const auto type = localDeclRef.getDecl()->getType(); + // We can only add overload candidates if this is known to be a function + if(const auto funType = as<FuncType>(type)) + AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr); + else + return; + } else { // TODO(tfoley): any other cases needed here? + return; } } @@ -1671,6 +1702,16 @@ namespace Slang { if (IsErrorExpr(arg)) return CreateErrorExpr(expr); + + // If this argument is itself an overloaded value without a type + // then we can't sensibly continue + if(!arg->type && (as<OverloadedExpr>(arg) || as<OverloadedExpr2>(arg))) + { + getSink()->diagnose( + expr->loc, + Diagnostics::overloadedParameterToHigherOrderFunction); + return CreateErrorExpr(expr); + } } for (auto& arg : expr->arguments) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d4bb63f52..25c86fa05 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -480,6 +480,8 @@ DIAGNOSTIC(39999, Warning, floatLiteralTooSmall, "'$1' is smaller than the small DIAGNOSTIC(39999, Error, unableToFindSymbolInModule, "unable to find the mangled symbol '$0' in module '$1'") +DIAGNOSTIC(39999, Error, overloadedParameterToHigherOrderFunction, "passing overloaded functions to higher order functions is not supported") + // 38xxx DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 88b616996..8a04e5cc8 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -12,6 +12,7 @@ #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" #include "slang-ir-autodiff.h" +#include "slang-ir-defunctionalization.h" #include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" @@ -616,6 +617,13 @@ Result linkAndOptimizeIR( break; } + // Few of our targets support higher order functions, and + // we don't have the backend code to emit higher order functions for those + // which do. + // Specialize away these parameters + // TODO: We should implement a proper defunctionalization pass + specializeHigherOrderParameters(codeGenContext, irModule); + // For all targets, we translate load/store operations // of aggregate types from/to byte-address buffers into // stores of individual scalar or vector values. diff --git a/source/slang/slang-ir-defunctionalization.cpp b/source/slang/slang-ir-defunctionalization.cpp new file mode 100644 index 000000000..ac4258753 --- /dev/null +++ b/source/slang/slang-ir-defunctionalization.cpp @@ -0,0 +1,39 @@ +#include "slang-ir-defunctionalization.h" + +#include "slang-ir-insts.h" +#include "slang-ir-specialize-function-call.h" +#include "slang-ir-ssa-simplification.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct FunctionParameterSpecializationCondition : FunctionCallSpecializeCondition +{ + TargetRequest* targetRequest = nullptr; + + bool doesParamWantSpecialization(IRParam* param, IRInst* /*arg*/) + { + IRType* type = param->getDataType(); + return as<IRFuncType>(type); + } +}; + +bool specializeHigherOrderParameters( + CodeGenContext* codeGenContext, + IRModule* module) +{ + bool result = false; + FunctionParameterSpecializationCondition condition; + condition.targetRequest = codeGenContext->getTargetReq(); + bool changed = true; + while (changed) + { + changed = specializeFunctionCalls(codeGenContext, module, &condition); + simplifyIR(module); + result |= changed; + } + return result; +} + +} // namespace Slang diff --git a/source/slang/slang-ir-defunctionalization.h b/source/slang/slang-ir-defunctionalization.h new file mode 100644 index 000000000..b1be8e3aa --- /dev/null +++ b/source/slang/slang-ir-defunctionalization.h @@ -0,0 +1,19 @@ +// Aspirational filename +#pragma once + +namespace Slang +{ + struct CodeGenContext; + struct IRModule; + struct IRType; + + /// Specialize calls to higher order functions + /// + /// This pass will rewrite any calls to higher order functions passing + /// global functions with calls to specialized versions simply + /// referencing the global. + /// + bool specializeHigherOrderParameters( + CodeGenContext* codeGenContext, + IRModule* module); +} diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index 9b7cdaea6..bc238e6ec 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -31,6 +31,9 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam* // if (as<IRGlobalParam>(arg)) return true; + // Similarly for these global values + if( as<IRGlobalValueWithCode>(arg) ) return true; + // As we will see later, we can also // specialize a call when the argument // is the result of indexing into an @@ -505,6 +508,11 @@ struct FunctionParameterSpecializationContext // ioInfo.key.vals.add(oldGlobalParam); } + else if( auto globalConstant = as<IRGlobalValueWithCode>(oldArg) ) + { + // Similarly for other global constants + ioInfo.key.vals.add(globalConstant); + } else if( oldArg->getOp() == kIROp_GetElement ) { // This is the case where the `oldArg` is @@ -626,6 +634,12 @@ struct FunctionParameterSpecializationContext // return globalParam; } + if( auto globalFunc = as<IRGlobalValueWithCode>(oldArg) ) + { + // As above, the identity of the specialized function is sufficient + // to resolve the uses + return globalFunc; + } else if( oldArg->getOp() == kIROp_GetElement ) { // This is the case where the argument is diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 433047741..e9c5f8fe5 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -411,6 +411,24 @@ public: return dispatchIfNotNull(expr->originalExpr); } bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); } + bool visitFuncTypeExpr(FuncTypeExpr* expr) + { + for(const auto& t : expr->parameters) + { + if(!dispatchIfNotNull(t.exp)) + return false; + } + return dispatchIfNotNull(expr->result.exp); + } + bool visitTupleTypeExpr(TupleTypeExpr* expr) + { + for(auto t : expr->members) + { + if(dispatchIfNotNull(t.exp)) + return true; + } + return false; + } bool visitTryExpr(TryExpr* expr) { return dispatchIfNotNull(expr->base); } bool visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) { diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index fc8331bce..d95f7d4e5 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -22,7 +22,8 @@ static const char* kDeclKeywords[] = { "class", "struct", "interface", "public", "private", "internal", "protected", "typedef", "typealias", "uniform", "export", "groupshared", "extension", "associatedtype", "namespace", "This", "using", - "__generic", "__exported", "import", "enum", "cbuffer", "tbuffer", "func"}; + "__generic", "__exported", "import", "enum", "cbuffer", "tbuffer", "func", + "functype"}; static const char* kStmtKeywords[] = { "if", "else", "switch", "case", "default", "return", "try", "throw", "throws", "catch", "while", "for", @@ -33,7 +34,7 @@ static const char* kStmtKeywords[] = { "extension", "associatedtype", "this", "namespace", "This", "using", "__generic", "__exported", "import", "enum", "break", "continue", "discard", "defer", "cbuffer", "tbuffer", "func", "is", - "as", "nullptr", "none", "true", "false"}; + "as", "nullptr", "none", "true", "false", "functype"}; static const char* hlslSemanticNames[] = { "register", diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ad338709d..b1726487d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4512,6 +4512,18 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitFuncTypeExpr(FuncTypeExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("type expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + + LoweredValInfo visitTupleTypeExpr(TupleTypeExpr* /*expr*/) + { + SLANG_UNIMPLEMENTED_X("type expression during code generation"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + LoweredValInfo visitPointerTypeExpr(PointerTypeExpr* /*expr*/) { SLANG_UNIMPLEMENTED_X("'*' type expression during code generation"); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index a8ab98254..ef9bf4938 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -177,7 +177,7 @@ namespace Slang break; default: - SLANG_UNEXPECTED("unimplemented case in mangling"); + SLANG_UNEXPECTED("unimplemented case in base type mangling"); break; } } @@ -238,9 +238,31 @@ namespace Slang { emitRaw(context, "E"); } + else if (const auto bottomType = dynamicCast<BottomType>(type)) + { + emitRaw(context, "B"); + } + else if (auto funcType = dynamicCast<FuncType>(type)) + { + emitRaw(context, "F"); + auto n = funcType->getParamCount(); + emit(context, n); + for(Index i = 0; i < n; ++i) + emitType(context, funcType->getParamType(i)); + emitType(context, funcType->getResultType()); + emitType(context, funcType->getErrorType()); + } + else if (auto tupleType = dynamicCast<TupleType>(type)) + { + emitRaw(context, "Tu"); + auto n = tupleType->getMemberCount(); + emit(context, n); + for(Index i = 0; i < n; ++i) + emitType(context, tupleType->getMember(i)); + } else { - SLANG_UNEXPECTED("unimplemented case in mangling"); + SLANG_UNEXPECTED("unimplemented case in type mangling"); } } @@ -307,7 +329,7 @@ namespace Slang } else { - SLANG_UNEXPECTED("unimplemented case in mangling"); + SLANG_UNEXPECTED("unimplemented case in val mangling"); } } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 78433a96f..cce4b7e7b 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2243,6 +2243,39 @@ namespace Slang return parseThisTypeExpr(parser); } + // (a,b,c) style tuples, curently unused +#if 0 + static Expr* parseTupleTypeExpr(Parser* parser) + { + parser->ReadToken(TokenType::LParent); + TupleTypeExpr* expr = parser->astBuilder->create<TupleTypeExpr>(); + while(!AdvanceIfMatch(parser, MatchedTokenType::Parentheses)) + { + expr->members.add(parser->ParseTypeExp()); + if(AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + return expr; + } +#endif + + static Expr* parseFuncTypeExpr(Parser* parser) + { + parser->ReadToken(TokenType::LParent); + auto expr = parser->astBuilder->create<FuncTypeExpr>(); + while(!AdvanceIfMatch(parser, MatchedTokenType::Parentheses)) + { + expr->parameters.add(parser->ParseTypeExp()); + if(AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + parser->ReadToken(TokenType::RightArrow); + expr->result = parser->ParseTypeExp(); + return expr; + } + /// Apply the given `modifiers` (if any) to the given `typeExpr` static Expr* _applyModifiersToTypeExpr(Parser* parser, Expr* typeExpr, Modifiers const& modifiers) { @@ -2438,6 +2471,17 @@ namespace Slang typeSpec.expr = parseThisTypeExpr(parser); return typeSpec; } + // Uncomment should we decide to enable (a,b,c) tuple types + // else if(parser->LookAheadToken(TokenType::LParent)) + // { + // typeSpec.expr = parseTupleTypeExpr(parser); + // return typeSpec; + // } + else if(AdvanceIf(parser, "functype")) + { + typeSpec.expr = parseFuncTypeExpr(parser); + return typeSpec; + } Token typeName = parser->ReadToken(TokenType::Identifier); @@ -4820,6 +4864,8 @@ namespace Slang return parsePostfixTypeSuffix(parser, typeExpr); } + static Expr* _parseInfixTypeExpr(Parser* parser); + static Expr* _parseInfixTypeExprSuffix(Parser* parser, Expr* leftExpr) { for(;;) @@ -4829,16 +4875,20 @@ namespace Slang // a conjunction type expression. auto loc = peekToken(parser).loc; - if(!AdvanceIf(parser, TokenType::OpBitAnd)) - break; - - auto rightExpr = _parsePostfixTypeExpr(parser); + if(AdvanceIf(parser, TokenType::OpBitAnd)) + { + auto rightExpr = _parsePostfixTypeExpr(parser); - auto andExpr = parser->astBuilder->create<AndTypeExpr>(); - andExpr->loc = loc; - andExpr->left = TypeExp(leftExpr); - andExpr->right = TypeExp(rightExpr); - leftExpr = andExpr; + auto andExpr = parser->astBuilder->create<AndTypeExpr>(); + andExpr->loc = loc; + andExpr->left = TypeExp(leftExpr); + andExpr->right = TypeExp(rightExpr); + leftExpr = andExpr; + } + else + { + break; + } } return leftExpr; @@ -4846,8 +4896,9 @@ namespace Slang /// Parse an infix type expression. /// - /// Currently, the only infix type expression we support is the `&` - /// operator for forming interface conjunctions. + /// Currently, the only infix type expressions we support are the `&` + /// operator for forming interface conjunctions and the `->` operator + /// for functions. /// static Expr* _parseInfixTypeExpr(Parser* parser) { |
