diff options
Diffstat (limited to 'Whisper/Utils/parallelFor.cpp')
| -rw-r--r-- | Whisper/Utils/parallelFor.cpp | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/Whisper/Utils/parallelFor.cpp b/Whisper/Utils/parallelFor.cpp new file mode 100644 index 0000000..c2b324b --- /dev/null +++ b/Whisper/Utils/parallelFor.cpp @@ -0,0 +1,144 @@ +#include "stdafx.h" +#include "parallelFor.h" + +namespace +{ + class alignas( 64 ) ParallelForContext + { + volatile long threadIndex; + volatile HRESULT status; + + alignas( 64 ) void* const context; + const Whisper::pfnParallelForCallback pfn; + + static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ); + + public: + + ParallelForContext( void* ctx, Whisper::pfnParallelForCallback pfn ); + + PTP_WORK createWork(); + + HRESULT getStatus() const; + }; + + ParallelForContext::ParallelForContext( void* ctx, Whisper::pfnParallelForCallback callback ) : + threadIndex( 1 ), + status( S_FALSE ), + context( ctx ), + pfn( callback ) + { } + + PTP_WORK ParallelForContext::createWork() + { + return CreateThreadpoolWork( &callbackStatic, this, nullptr ); + } + + void __stdcall ParallelForContext::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ) + { + ParallelForContext& context = *(ParallelForContext*)pv; + int ith = InterlockedIncrement( &context.threadIndex ); + ith--; + const HRESULT hr = context.pfn( ith, context.context ); + if( SUCCEEDED( hr ) ) + return; + InterlockedCompareExchange( &context.status, hr, S_FALSE ); + } + + HRESULT ParallelForContext::getStatus() const + { + const HRESULT hr = status; + if( SUCCEEDED( hr ) ) + return S_OK; + return hr; + } +} + +namespace Whisper +{ + HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx ) + { + if( threadsCount < 1 ) + return E_BOUNDS; + if( threadsCount == 1 ) + return pfn( 0, ctx ); + + ParallelForContext context{ ctx, pfn }; + + PTP_WORK const pw = context.createWork(); + if( nullptr == pw ) + return getLastHr(); + + for( int i = 1; i < threadsCount; i++ ) + SubmitThreadpoolWork( pw ); + + const HRESULT hr0 = pfn( 0, ctx ); + + WaitForThreadpoolWorkCallbacks( pw, FALSE ); + CloseThreadpoolWork( pw ); + + if( FAILED( hr0 ) ) + return hr0; + return context.getStatus(); + } +} + +using namespace Whisper; + +ThreadPoolWork::~ThreadPoolWork() +{ + if( nullptr != work ) + { + CloseThreadpoolWork( work ); + work = nullptr; + } +} + +HRESULT ThreadPoolWork::create() +{ + if( nullptr == work ) + { + work = CreateThreadpoolWork( &callbackStatic, this, nullptr ); + if( nullptr != work ) + return S_OK; + return getLastHr(); + } + return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); +} + +HRESULT ThreadPoolWork::parallelFor( int threadsCount ) noexcept +{ + if( nullptr != work ) + { + if( threadsCount <= 1 ) + return threadPoolCallback( 0 ); + + threadIndex = 1; + status = S_FALSE; + for( int i = 1; i < threadsCount; i++ ) + SubmitThreadpoolWork( work ); + + const HRESULT hr0 = threadPoolCallback( 0 ); + + WaitForThreadpoolWorkCallbacks( work, FALSE ); + + if( FAILED( hr0 ) ) + return hr0; + if( SUCCEEDED( status ) ) + return S_OK; + return status; + } + + return OLE_E_BLANK; +} + +void __stdcall ThreadPoolWork::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ) +{ + ThreadPoolWork* tpw = (ThreadPoolWork*)pv; + int ith = InterlockedIncrement( &tpw->threadIndex ); + ith--; + const HRESULT hr = tpw->threadPoolCallback( ith ); + if( SUCCEEDED( hr ) ) + return; + InterlockedCompareExchange( &tpw->status, hr, S_FALSE ); +}
\ No newline at end of file |
