summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-expr.cpp14
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp12
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp110
-rw-r--r--source/slang/slang-ir-synthesize-active-mask.cpp59
7 files changed, 145 insertions, 59 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index abdd89b01..05a6ed249 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -422,7 +422,9 @@ namespace Slang
return result;
}
- Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr)
+ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(
+ LookupResultItem const& item,
+ Expr* originalExpr)
{
// If the only result from lookup is an entry in an interface decl, it could be that
// the user is leaving out an explicit definition for the requirement and depending on
@@ -521,13 +523,16 @@ namespace Slang
conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
conformanceDecl->parentDecl = structDecl;
structDecl->members.add(conformanceDecl);
+ structDecl->parentDecl = parent;
synthesizedDecl = structDecl;
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
typeDef->nameAndLoc.name = getName("Differential");
- auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
- typeDef->type.type = DeclRefType::create(m_astBuilder, declRef);
typeDef->parentDecl = structDecl;
+
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
+
+ typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
structDecl->members.add(typeDef);
}
break;
@@ -545,8 +550,9 @@ namespace Slang
auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
addModifier(synthesizedDecl, toBeSynthesized);
+ auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
return ConstructDeclRefExpr(
- makeDeclRef(synthesizedDecl),
+ synthDeclMemberRef,
nullptr,
originalExpr ? originalExpr->loc : SourceLoc(),
originalExpr);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index bbf6885a8..86136a010 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -459,6 +459,9 @@ Result linkAndOptimizeIR(
break;
}
+ if (sink->getErrorCount() != 0)
+ return SLANG_FAIL;
+
// If we have a target that is GPU like we use the string hashing mechanism
// but for that to work we need to inline such that calls (or returns) of strings
// boil down into getStringHash(stringLiteral)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index bd32a1896..026b8b110 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -552,8 +552,6 @@ INST(TargetSwitch, targetSwitch, 1, 0)
// A generic asm inst has an return semantics that terminates the control flow.
INST(GenericAsm, GenericAsm, 1, 0)
-INST(RequirePrelude, RequirePrelude, 1, 0)
-
INST(discard, discard, 0, 0)
/* IRUnreachable */
@@ -563,6 +561,8 @@ INST_RANGE(Unreachable, MissingReturn, Unreachable)
INST_RANGE(TerminatorInst, Return, Unreachable)
+INST(RequirePrelude, RequirePrelude, 1, 0)
+
// TODO: We should consider splitting the basic arithmetic/comparison
// ops into cases for signed integers, unsigned integers, and floating-point
// values, to better match downstream targets that want to treat them
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 070f989b5..c04450b82 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3040,7 +3040,7 @@ struct IRSPIRVAsm : IRInst
}
};
-struct IRGenericAsm : IRInst
+struct IRGenericAsm : IRTerminatorInst
{
IR_LEAF_ISA(GenericAsm)
UnownedStringSlice getAsm() { return as<IRStringLit>(getOperand(0))->getStringSlice(); }
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index b5af2d974..6970942c9 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -472,15 +472,9 @@ bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink)
for (auto inst : module->getGlobalInsts())
{
if (auto genFunc = as<IRGeneric>(inst))
- {
- if (auto func = as<IRGlobalValueWithCode>(findGenericReturnVal(genFunc)))
- {
- bool result = unrollLoopsInFunc(module, func, sink);
- if (!result)
- return false;
- }
- }
- else if (auto func = as<IRGlobalValueWithCode>(inst))
+ continue;
+
+ if (auto func = as<IRGlobalValueWithCode>(inst))
{
bool result = unrollLoopsInFunc(module, func, sink);
if (!result)
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 41665ddf7..3a7e8b9fb 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -177,9 +177,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
}
List<IRInst*> resultElements;
auto elementType = arrayType->getElementType();
+ auto tupleElementType = translateToTupleType(builder, elementType);
for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
{
- auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i));
auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement);
if (!convertedElement)
return nullptr;
@@ -346,7 +347,7 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
List<IRType*> fieldTypes;
for (auto field : as<IRStructType>(type)->getFields())
{
- fieldTypes.add(translateToHostType(builder, field->getFieldType(), func));
+ fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink));
}
auto hostStructType = builder->createStructType();
@@ -358,6 +359,13 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag
return hostStructType;
}
+ case kIROp_ArrayType:
+ {
+ auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink);
+ if (!elementType)
+ return nullptr;
+ return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount());
+ }
default:
break;
}
@@ -422,13 +430,36 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer());
}
+ case kIROp_ArrayType:
+ {
+ auto cudaArrayType = cast<IRArrayType>(cudaType);
+ auto hostArrayType = cast<IRArrayType>(hostType);
+
+ List<IRInst*> resultElements;
+ for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++)
+ {
+ auto cudaElementType = cudaArrayType->getElementType();
+ auto hostElementType = hostArrayType->getElementType();
+ auto castedElement = castHostToCUDAType(
+ builder,
+ hostElementType,
+ cudaElementType,
+ builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i)));
+
+ SLANG_RELEASE_ASSERT(castedElement);
+ resultElements.add(castedElement);
+ }
+
+ return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
default:
break;
}
- // If translateToHostType worked correctly, we shouldn't get here.
- SLANG_UNREACHABLE("unhandled type");
+ // If translateToHostType worked correctly, there should be no unhandled cases here.
+ // However, we won't diagnose here since its already diagnosed in translateToHostType()
+ return nullptr;
}
void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc)
@@ -553,6 +584,12 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno
auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink);
if (outType)
*outType = type;
+
+ if (!type || sink->getErrorCount() > 0)
+ {
+ return nullptr;
+ }
+
auto hostParam = builder->emitParam(type);
// Add a namehint to the param by appending the suffix "_host".
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
@@ -600,6 +637,38 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink)
}
return;
}
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ IRBuilder builder(arrayType->getModule());
+ if (!arrayType->findDecoration<IRPyExportDecoration>())
+ builder.addPyExportDecoration(arrayType, UnownedStringSlice("Array"));
+
+ markTypeForPyExport(arrayType->getElementType(), sink);
+ return;
+ }
+}
+
+String tryGetExportTypeName(IRBuilder* builder, IRType* type)
+{
+ if (auto structType = as<IRStructType>(type))
+ {
+ if (auto pyExportDecoration = type->findDecoration<IRPyExportDecoration>())
+ return String(pyExportDecoration->getExportName());
+ else
+ return String("");
+ }
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ StringBuilder nameBuilder;
+ nameBuilder << "Array_";
+ nameBuilder << tryGetExportTypeName(builder, arrayType->getElementType());
+ nameBuilder << "_";
+ nameBuilder << cast<IRIntLit>(arrayType->getElementCount())->getValue();
+
+ return nameBuilder.produceString();
+ }
+ else
+ return String();
}
void generateReflectionForType(IRType* type, DiagnosticSink* sink)
@@ -609,7 +678,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// The list will contain the names of all the fields of the type.
//
- // TODO: Fix this to avoid emitting the same type reflection multiple times.
if (!type->findDecoration<IRPyExportDecoration>())
return;
@@ -635,20 +703,32 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
else
fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
- if (!field->getFieldType()->findDecoration<IRPyExportDecoration>())
- {
- fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
- continue;
- }
-
auto fieldType = field->getFieldType();
+ auto exportName = tryGetExportTypeName(&builder, fieldType);
- fieldTypeNames.add(
- builder.emitGetNativeString(
- builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName())));
+ if (exportName.getLength() > 0)
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice())));
+ else
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice(""))));
}
break;
}
+ case kIROp_ArrayType:
+ {
+ auto elementType = as<IRArrayType>(type)->getElementType();
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type"))));
+ fieldTypeNames.add(
+ builder.emitGetNativeString(
+ builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice())));
+
+ auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount());
+ fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size"))));
+
+ StringBuilder elementCountStr;
+ elementCountStr << elementCount->getValue();
+ fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice())));
+ break;
+ }
default:
break;
}
@@ -676,7 +756,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink)
// Set function name.
StringBuilder reflFuncExportName;
- reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName();
+ reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice();
builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice());
diff --git a/source/slang/slang-ir-synthesize-active-mask.cpp b/source/slang/slang-ir-synthesize-active-mask.cpp
index 75246d553..60e13b418 100644
--- a/source/slang/slang-ir-synthesize-active-mask.cpp
+++ b/source/slang/slang-ir-synthesize-active-mask.cpp
@@ -1855,37 +1855,40 @@ struct SynthesizeActiveMaskForFunctionContext
}
else if( toBlock->getPredecessors().getCount() > 1 )
{
- // If the target block is one with multiple
- // predecessors, such that it will have an
- // added block parameter (phi node) to select
- // the corect mask value, then we need to
- // pass along the mask value to use as an
- // additional argument on the unconditional branch.
- //
- // If the old unconditional branch was:
- //
- // <op>(arg0, arg1, arg2, ...);
- //
- // Then our new branch will be:
- //
- // <op>(arg0, arg1, arg2, ..., toActiveMask);
- //
- List<IRInst*> newOperands;
- UInt oldOperandCount = terminator->getOperandCount();
- for( UInt i = 0; i < oldOperandCount; ++i )
+ if (doesBlockNeedActiveMask(toBlock))
{
- newOperands.add(terminator->getOperand(i));
- }
- newOperands.add(toActiveMask);
+ // If the target block is one with multiple
+ // predecessors, such that it will have an
+ // added block parameter (phi node) to select
+ // the corect mask value, then we need to
+ // pass along the mask value to use as an
+ // additional argument on the unconditional branch.
+ //
+ // If the old unconditional branch was:
+ //
+ // <op>(arg0, arg1, arg2, ...);
+ //
+ // Then our new branch will be:
+ //
+ // <op>(arg0, arg1, arg2, ..., toActiveMask);
+ //
+ List<IRInst*> newOperands;
+ UInt oldOperandCount = terminator->getOperandCount();
+ for( UInt i = 0; i < oldOperandCount; ++i )
+ {
+ newOperands.add(terminator->getOperand(i));
+ }
+ newOperands.add(toActiveMask);
- IRInst* newTerminator = builder.emitIntrinsicInst(
- terminator->getFullType(),
- terminator->getOp(),
- newOperands.getCount(),
- newOperands.getBuffer());
+ IRInst* newTerminator = builder.emitIntrinsicInst(
+ terminator->getFullType(),
+ terminator->getOp(),
+ newOperands.getCount(),
+ newOperands.getBuffer());
- terminator->replaceUsesWith(newTerminator);
- terminator->removeAndDeallocate();
+ terminator->replaceUsesWith(newTerminator);
+ terminator->removeAndDeallocate();
+ }
}
else
{