@@ -30,6 +30,7 @@ type FieldMapping struct {
30
30
ConversionFunc string
31
31
SrcElemType string // 新增
32
32
DstElemType string // 新增
33
+ ImportPath string // 依赖的包
33
34
}
34
35
35
36
// CopyFuncInfo 存储拷贝函数信息
@@ -127,7 +128,7 @@ func processFields(
127
128
// 处理类型转换
128
129
srcType := types .ExprString (srcField .Type )
129
130
dstType := types .ExprString (field .Type )
130
- conversion := getTypeConversion (srcType , dstType , allowNarrow , singleToSlice , file , path )
131
+ conversion , importPath := getTypeConversion (srcType , dstType , allowNarrow , singleToSlice , file , path )
131
132
132
133
// 判断是否为嵌入字段
133
134
isEmbedded := false
@@ -147,6 +148,7 @@ func processFields(
147
148
ConversionFunc : getStructCopyFuncName (srcType , dstType ),
148
149
SrcElemType : getElementType (srcType ),
149
150
DstElemType : getElementType (dstType ),
151
+ ImportPath : importPath ,
150
152
})
151
153
152
154
if ! isSrc {
@@ -271,6 +273,13 @@ func generateCopyFunctionIfNeeded(srcType, dstType string, file *ast.File, path
271
273
}
272
274
generateCompleteCopyFunc (funcDecl , "src" , "dst" , srcType , dstType , fields )
273
275
// 注册生成的函数
276
+ importPath := []string {}
277
+ for _ , field := range fields {
278
+ if field .ImportPath != "" {
279
+ importPath = append (importPath , field .ImportPath )
280
+ }
281
+ }
282
+ addRequiredImports (file , importPath ... )
274
283
generatedFunctions .Store (funcName , funcDecl )
275
284
}
276
285
@@ -432,15 +441,16 @@ func getFieldMappings(srcType, dstType string, file *ast.File, ignoreCase, allow
432
441
if isSliceOrArray (srcType ) && isSliceOrArray (dstType ) {
433
442
srcElem := getElementType (srcType )
434
443
dstElem := getElementType (dstType )
435
- conversion := getTypeConversion (srcElem , dstElem , allowNarrow , singleToSlice , file , path )
444
+ conversion , _ := getTypeConversion (srcElem , dstElem , allowNarrow , singleToSlice , file , path )
436
445
if conversion != "" {
437
- sliceConv := generateSliceCopyFunc (srcElem , dstElem , conversion , file , path )
446
+ sliceConv , sliceImportPath := generateSliceCopyFunc (srcElem , dstElem , conversion , file , path )
438
447
return []FieldMapping {
439
448
{
440
449
SrcField : "" ,
441
450
DstField : "" ,
442
451
Conversion : sliceConv ,
443
452
IsSlice : true ,
453
+ ImportPath : sliceImportPath ,
444
454
},
445
455
}
446
456
}
@@ -484,7 +494,7 @@ func getFieldMappings(srcType, dstType string, file *ast.File, ignoreCase, allow
484
494
}
485
495
486
496
// 获取类型转换逻辑
487
- conversion := getTypeConversion (types .ExprString (srcField .Type ), types .ExprString (dstField .Type ), allowNarrow , singleToSlice , file , path )
497
+ conversion , importPath := getTypeConversion (types .ExprString (srcField .Type ), types .ExprString (dstField .Type ), allowNarrow , singleToSlice , file , path )
488
498
489
499
// 判断是否为嵌入字段
490
500
isEmbedded := isEmbeddedField (srcField ) || isEmbeddedField (dstField )
@@ -495,6 +505,7 @@ func getFieldMappings(srcType, dstType string, file *ast.File, ignoreCase, allow
495
505
Conversion : conversion ,
496
506
IsEmbedded : isEmbedded ,
497
507
ConversionFunc : getStructCopyFuncName (srcType , dstType ),
508
+ ImportPath : importPath ,
498
509
})
499
510
log .Printf ("Mapped field: %s -> %s (Conversion: %s)" , srcFieldPath , dstFieldPath , conversion )
500
511
@@ -576,10 +587,10 @@ func getSliceCopyFuncName(srcElem, dstElem string) string {
576
587
}
577
588
578
589
// 修改 getTypeConversion 函数签名,增加 file 参数
579
- func getTypeConversion (srcType , dstType string , allowNarrow , singleToSlice bool , file * ast.File , path string ) string {
590
+ func getTypeConversion (srcType , dstType string , allowNarrow , singleToSlice bool , file * ast.File , path string ) ( string , string ) {
580
591
// 类型相同无需转换
581
592
if srcType == dstType {
582
- return ""
593
+ return "" , ""
583
594
}
584
595
585
596
if isSliceOrArray (srcType ) && isSliceOrArray (dstType ) {
@@ -591,7 +602,7 @@ func getTypeConversion(srcType, dstType string, allowNarrow, singleToSlice bool,
591
602
}
592
603
// 处理结构体类型
593
604
if isStructType (srcType , file , path ) && isStructType (dstType , file , path ) {
594
- return generateStructConversionFunc (srcType , dstType )
605
+ return generateStructConversionFunc (srcType , dstType ), ""
595
606
}
596
607
597
608
// 处理指针类型
@@ -644,17 +655,17 @@ func isStructType(typeName string, file *ast.File, path string) bool {
644
655
}
645
656
646
657
// 核心处理函数
647
- func handleBasicConversion (src , dst string , allowNarrow bool ) string {
658
+ func handleBasicConversion (src , dst string , allowNarrow bool ) ( string , string ) {
648
659
// 整数类型转换
649
660
if isIntegerType (src ) && isIntegerType (dst ) {
650
661
srcWidth := getIntWidth (src )
651
662
dstWidth := getIntWidth (dst )
652
663
653
664
if srcWidth > dstWidth && ! allowNarrow {
654
665
log .Printf ("Narrowing conversion disabled: %s -> %s" , src , dst )
655
- return ""
666
+ return "" , ""
656
667
}
657
- return dst // 返回类型名称作为转换函数
668
+ return dst , "" // 返回类型名称作为转换函数
658
669
}
659
670
660
671
// 其他基本类型转换
@@ -702,13 +713,13 @@ func isPointerType(typeName string) bool {
702
713
}
703
714
704
715
// 新增指针转换处理函数
705
- func handlePointerConversion (srcType , dstType string , allowNarrow , singleToSlice bool , file * ast.File , path string ) string {
716
+ func handlePointerConversion (srcType , dstType string , allowNarrow , singleToSlice bool , file * ast.File , path string ) ( string , string ) {
706
717
// 获取基础类型
707
718
baseSrc := strings .TrimPrefix (srcType , "*" )
708
719
baseDst := strings .TrimPrefix (dstType , "*" )
709
720
710
721
// 递归获取基础类型转换
711
- baseConv := getTypeConversion (baseSrc , baseDst , allowNarrow , singleToSlice , file , path )
722
+ baseConv , importPath := getTypeConversion (baseSrc , baseDst , allowNarrow , singleToSlice , file , path )
712
723
713
724
// 生成指针转换逻辑
714
725
return fmt .Sprintf (`func(src %s) %s {
@@ -718,7 +729,7 @@ func handlePointerConversion(srcType, dstType string, allowNarrow, singleToSlice
718
729
dst := new(%s)
719
730
%s
720
731
return dst
721
- }` , srcType , dstType , baseDst , generateElementConversion ("*src" , "dst" , baseConv ))
732
+ }` , srcType , dstType , baseDst , generateElementConversion ("*src" , "dst" , baseConv )), importPath
722
733
}
723
734
724
735
func Main (dir string ) {
@@ -805,9 +816,16 @@ func Main(dir string) {
805
816
// 提取字段映射关系
806
817
fields := getFieldMappings (srcType , dstType , file , ignoreCase , allowNarrow , singleToSlice , fieldMappings , path )
807
818
819
+ importPath := []string {}
820
+ for _ , field := range fields {
821
+ if field .ImportPath != "" {
822
+ importPath = append (importPath , field .ImportPath )
823
+ }
824
+ }
808
825
// 生成完整的拷贝函数
809
826
generateCompleteCopyFunc (funcDecl , srcVar , dstVar , srcType , dstType , fields )
810
-
827
+ // 加入必要的导入
828
+ addRequiredImports (file , importPath ... )
811
829
// 将修改后的 AST 写回文件
812
830
writeFile (fset , file , path )
813
831
return true
0 commit comments