@@ -275,7 +275,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
275
275
}
276
276
277
277
torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device_);
278
- nppCtx_->hStream = at::cuda::getCurrentCUDAStream (deviceIndex).stream ();
278
+
279
+ // Create a CUDA event and record it on the AVFrame's CUDA stream. That's the
280
+ // NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
281
+ // We will be waiting for this event to complete before calling the NPP
282
+ // functions, to ensure NVDEC has finished decoding the frame before running
283
+ // the NPP color-conversion.
284
+ // Note that our code is generic and assumes that the NVDEC's stream can be
285
+ // arbitrary, but unfortunately we know it's hardcoded to be the default
286
+ // stream by FFmpeg:
287
+ // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
288
+ TORCH_CHECK (
289
+ hwFramesCtx->device_ctx != nullptr ,
290
+ " The AVFrame's hw_frames_ctx does not have a device_ctx. " );
291
+ auto cudaDeviceCtx =
292
+ static_cast <AVCUDADeviceContext*>(hwFramesCtx->device_ctx ->hwctx );
293
+ at::cuda::CUDAEvent nvdecDoneEvent;
294
+ at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
295
+ c10::cuda::getStreamFromExternal (cudaDeviceCtx->stream , deviceIndex);
296
+ nvdecDoneEvent.record (nvdecStream);
297
+
298
+ // Don't start NPP work before NVDEC is done decoding the frame!
299
+ at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream (deviceIndex);
300
+ nvdecDoneEvent.block (nppStream);
301
+
302
+ // Point the existing NPP context at the NPP stream and refresh its stream flags.
303
+ nppCtx_->hStream = nppStream.stream ();
279
304
cudaError_t err =
280
305
cudaStreamGetFlags (nppCtx_->hStream , &nppCtx_->nStreamFlags );
281
306
TORCH_CHECK (
0 commit comments