Skip to content

Commit 9fa8554

Browse files
authored
Fix NVDEC -> NPP CUDA stream sync issue (#868)
1 parent cb5c614 commit 9fa8554

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
275275
}
276276

277277
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
278-
nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
278+
279+
// Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
280+
// NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
281+
// We will be waiting for this event to complete before calling the NPP
282+
// functions, to ensure NVDEC has finished decoding the frame before running
283+
// the NPP color-conversion.
284+
// Note that our code is generic and assumes that the NVDEC's stream can be
285+
// arbitrary, but unfortunately we know it's hardcoded to be the default
286+
// stream by FFmpeg:
287+
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
288+
TORCH_CHECK(
289+
hwFramesCtx->device_ctx != nullptr,
290+
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
291+
auto cudaDeviceCtx =
292+
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
293+
at::cuda::CUDAEvent nvdecDoneEvent;
294+
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
295+
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
296+
nvdecDoneEvent.record(nvdecStream);
297+
298+
// Don't start NPP work before NVDEC is done decoding the frame!
299+
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
300+
nvdecDoneEvent.block(nppStream);
301+
302+
// Create the NPP context if we haven't yet.
303+
nppCtx_->hStream = nppStream.stream();
279304
cudaError_t err =
280305
cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
281306
TORCH_CHECK(

0 commit comments

Comments
 (0)