Skip to content

Commit 78d7550

Browse files
yongqiangmaSigureMoYqGe585
authored
Compatible with torch 3rd (#74402)
* [Convert] POC for PyTorch compat conversion convert torch C++ api to paddle api --------- Co-authored-by: SigureMo <[email protected]> Co-authored-by: Yuqiang Ge <[email protected]>
1 parent cc405f6 commit 78d7550

File tree

96 files changed

+7558
-63
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+7558
-63
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,9 @@ if(WITH_PROFILER)
602602
endif()
603603

604604
include_directories("${PADDLE_SOURCE_DIR}")
605+
include_directories("${PADDLE_SOURCE_DIR}/paddle/phi/api/include/compat/")
606+
include_directories(
607+
"${PADDLE_SOURCE_DIR}/paddle/phi/api/include/compat/torch/csrc/api/include/")
605608

606609
if(WITH_NV_JETSON)
607610
set(WITH_ARM

paddle/fluid/pybind/pybind.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ limitations under the License. */
8080
#include "paddle/fluid/imperative/amp_auto_cast.h"
8181
#include "paddle/fluid/imperative/layer.h"
8282
#include "paddle/fluid/prim/utils/utils.h"
83+
#include "paddle/fluid/pybind/torch_compat.h"
8384
#include "paddle/phi/common/bfloat16.h"
8485
#include "paddle/phi/common/float16.h"
8586
#include "paddle/phi/common/int_array.h"
@@ -4139,6 +4140,9 @@ All parameter, weight, gradient are variables in Paddle.
41394140
BindVjp(&m);
41404141
BindDecompRule(&m);
41414142
BindDecompVjp(&m);
4143+
py::module torch_compat = m.def_submodule(
4144+
"torch_compat", "Compatibility layer for PyTorch-like APIs");
4145+
BindTorchCompat(&torch_compat);
41424146
#ifdef PADDLE_WITH_DISTRIBUTE
41434147
BindDistApi(&m);
41444148
#endif

0 commit comments

Comments
 (0)