Skip to content

Commit 43b577b

Browse files
committed
fix:
1 parent e1855c6 commit 43b577b

File tree

7 files changed

+173
-47
lines changed

7 files changed

+173
-47
lines changed

cmd/quickcopy/quickcopy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ package main
33
import "github.com/antlabs/quickcopy"
44

55
func main() {
6-
// quickcopy.Main("/Users/guonaihong/my-github/quickcopy/mytest/copy_slice")
6+
// quickcopy.Main("/Users/guonaihong/my-github/quickcopy/mytest/example")
77
quickcopy.Main(".")
88
}

import.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package quickcopy
2+
3+
import (
4+
"fmt"
5+
"go/ast"
6+
"go/token"
7+
"log"
8+
"strings"
9+
)
10+
11+
func addRequiredImports(file *ast.File, importPath ...string) {
12+
13+
log.Printf("addRequiredImports:%v\n", importPath)
14+
15+
// 需要添加的包
16+
requiredImports := make(map[string]bool)
17+
18+
for _, pkg := range importPath {
19+
requiredImports[pkg] = true
20+
}
21+
22+
// 查找现有的 import 声明
23+
var importDecl *ast.GenDecl
24+
for _, decl := range file.Decls {
25+
if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
26+
importDecl = genDecl
27+
break
28+
}
29+
}
30+
31+
// 如果没有 import 声明,创建一个新的
32+
if importDecl == nil {
33+
importSpecs := make([]ast.Spec, 0, len(requiredImports))
34+
for pkg := range requiredImports {
35+
importSpecs = append(importSpecs, &ast.ImportSpec{
36+
Path: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s"`, pkg)},
37+
})
38+
}
39+
40+
// 将新的 import 声明添加到文件顶部
41+
file.Decls = append([]ast.Decl{&ast.GenDecl{
42+
Tok: token.IMPORT,
43+
Specs: importSpecs,
44+
}}, file.Decls...)
45+
return
46+
}
47+
48+
// 如果已有 import 声明,检查并添加缺失的包
49+
existingImports := make(map[string]bool)
50+
for _, spec := range importDecl.Specs {
51+
if importSpec, ok := spec.(*ast.ImportSpec); ok {
52+
existingImports[strings.Trim(importSpec.Path.Value, `"`)] = true
53+
}
54+
}
55+
56+
// 添加缺失的包
57+
for pkg := range requiredImports {
58+
if !existingImports[pkg] {
59+
log.Printf("Adding import: %s", pkg)
60+
importDecl.Specs = append(importDecl.Specs, &ast.ImportSpec{
61+
Path: &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf(`"%s"`, pkg)},
62+
})
63+
}
64+
}
65+
}

mytest/basic/basic_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ func copySliceIntFromSliceString(src []string) []int {
4343
}(
4444
src[i])
4545
}
46-
return dst // :quickcopy
46+
return dst // :quickcopy
4747
}

mytest/example/example.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package example
2+
3+
import (
4+
"time"
5+
6+
"github.com/google/uuid"
7+
"fmt"
8+
)
9+
10+
// 源结构体
11+
type Source struct {
12+
Name string
13+
Age int
14+
Birthday time.Time
15+
ID uuid.UUID
16+
}
17+
18+
// 目标结构体
19+
type Destination struct {
20+
Name string
21+
Age string // 支持类型自动转换
22+
Birthday string // time.Time 将自动转为 RFC3339 格式
23+
ID string // UUID 将自动转为字符串
24+
}
25+
26+
// :quickcopy
27+
func CopyToDestination(dst *Destination, src *Source) {
28+
29+
dst.
30+
Name = src.Name
31+
32+
dst.Age = fmt.
33+
Sprint(src.Age)
34+
dst.Birthday = func(t time.Time) string {
35+
return t.Format(time.RFC3339)
36+
}(src.Birthday)
37+
38+
dst.ID = func(u uuid.UUID) string {
39+
return u.String()
40+
}(src.ID)
41+
}

quickcopy.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type FieldMapping struct {
3030
ConversionFunc string
3131
SrcElemType string // 新增
3232
DstElemType string // 新增
33+
ImportPath string // 依赖的包
3334
}
3435

3536
// CopyFuncInfo 存储拷贝函数信息
@@ -127,7 +128,7 @@ func processFields(
127128
// 处理类型转换
128129
srcType := types.ExprString(srcField.Type)
129130
dstType := types.ExprString(field.Type)
130-
conversion := getTypeConversion(srcType, dstType, allowNarrow, singleToSlice, file, path)
131+
conversion, importPath := getTypeConversion(srcType, dstType, allowNarrow, singleToSlice, file, path)
131132

132133
// 判断是否为嵌入字段
133134
isEmbedded := false
@@ -147,6 +148,7 @@ func processFields(
147148
ConversionFunc: getStructCopyFuncName(srcType, dstType),
148149
SrcElemType: getElementType(srcType),
149150
DstElemType: getElementType(dstType),
151+
ImportPath: importPath,
150152
})
151153

152154
if !isSrc {
@@ -271,6 +273,13 @@ func generateCopyFunctionIfNeeded(srcType, dstType string, file *ast.File, path
271273
}
272274
generateCompleteCopyFunc(funcDecl, "src", "dst", srcType, dstType, fields)
273275
// 注册生成的函数
276+
importPath := []string{}
277+
for _, field := range fields {
278+
if field.ImportPath != "" {
279+
importPath = append(importPath, field.ImportPath)
280+
}
281+
}
282+
addRequiredImports(file, importPath...)
274283
generatedFunctions.Store(funcName, funcDecl)
275284
}
276285

@@ -432,15 +441,16 @@ func getFieldMappings(srcType, dstType string, file *ast.File, ignoreCase, allow
432441
if isSliceOrArray(srcType) && isSliceOrArray(dstType) {
433442
srcElem := getElementType(srcType)
434443
dstElem := getElementType(dstType)
435-
conversion := getTypeConversion(srcElem, dstElem, allowNarrow, singleToSlice, file, path)
444+
conversion, _ := getTypeConversion(srcElem, dstElem, allowNarrow, singleToSlice, file, path)
436445
if conversion != "" {
437-
sliceConv := generateSliceCopyFunc(srcElem, dstElem, conversion, file, path)
446+
sliceConv, sliceImportPath := generateSliceCopyFunc(srcElem, dstElem, conversion, file, path)
438447
return []FieldMapping{
439448
{
440449
SrcField: "",
441450
DstField: "",
442451
Conversion: sliceConv,
443452
IsSlice: true,
453+
ImportPath: sliceImportPath,
444454
},
445455
}
446456
}
@@ -484,7 +494,7 @@ func getFieldMappings(srcType, dstType string, file *ast.File, ignoreCase, allow
484494
}
485495

486496
// 获取类型转换逻辑
487-
conversion := getTypeConversion(types.ExprString(srcField.Type), types.ExprString(dstField.Type), allowNarrow, singleToSlice, file, path)
497+
conversion, importPath := getTypeConversion(types.ExprString(srcField.Type), types.ExprString(dstField.Type), allowNarrow, singleToSlice, file, path)
488498

489499
// 判断是否为嵌入字段
490500
isEmbedded := isEmbeddedField(srcField) || isEmbeddedField(dstField)
@@ -495,6 +505,7 @@ func getFieldMappings(srcType, dstType string, file *ast.File, ignoreCase, allow
495505
Conversion: conversion,
496506
IsEmbedded: isEmbedded,
497507
ConversionFunc: getStructCopyFuncName(srcType, dstType),
508+
ImportPath: importPath,
498509
})
499510
log.Printf("Mapped field: %s -> %s (Conversion: %s)", srcFieldPath, dstFieldPath, conversion)
500511

@@ -576,10 +587,10 @@ func getSliceCopyFuncName(srcElem, dstElem string) string {
576587
}
577588

578589
// 修改 getTypeConversion 函数签名,增加 file 参数
579-
func getTypeConversion(srcType, dstType string, allowNarrow, singleToSlice bool, file *ast.File, path string) string {
590+
func getTypeConversion(srcType, dstType string, allowNarrow, singleToSlice bool, file *ast.File, path string) (string, string) {
580591
// 类型相同无需转换
581592
if srcType == dstType {
582-
return ""
593+
return "", ""
583594
}
584595

585596
if isSliceOrArray(srcType) && isSliceOrArray(dstType) {
@@ -591,7 +602,7 @@ func getTypeConversion(srcType, dstType string, allowNarrow, singleToSlice bool,
591602
}
592603
// 处理结构体类型
593604
if isStructType(srcType, file, path) && isStructType(dstType, file, path) {
594-
return generateStructConversionFunc(srcType, dstType)
605+
return generateStructConversionFunc(srcType, dstType), ""
595606
}
596607

597608
// 处理指针类型
@@ -644,17 +655,17 @@ func isStructType(typeName string, file *ast.File, path string) bool {
644655
}
645656

646657
// 核心处理函数
647-
func handleBasicConversion(src, dst string, allowNarrow bool) string {
658+
func handleBasicConversion(src, dst string, allowNarrow bool) (string, string) {
648659
// 整数类型转换
649660
if isIntegerType(src) && isIntegerType(dst) {
650661
srcWidth := getIntWidth(src)
651662
dstWidth := getIntWidth(dst)
652663

653664
if srcWidth > dstWidth && !allowNarrow {
654665
log.Printf("Narrowing conversion disabled: %s -> %s", src, dst)
655-
return ""
666+
return "", ""
656667
}
657-
return dst // 返回类型名称作为转换函数
668+
return dst, "" // 返回类型名称作为转换函数
658669
}
659670

660671
// 其他基本类型转换
@@ -702,13 +713,13 @@ func isPointerType(typeName string) bool {
702713
}
703714

704715
// 新增指针转换处理函数
705-
func handlePointerConversion(srcType, dstType string, allowNarrow, singleToSlice bool, file *ast.File, path string) string {
716+
func handlePointerConversion(srcType, dstType string, allowNarrow, singleToSlice bool, file *ast.File, path string) (string, string) {
706717
// 获取基础类型
707718
baseSrc := strings.TrimPrefix(srcType, "*")
708719
baseDst := strings.TrimPrefix(dstType, "*")
709720

710721
// 递归获取基础类型转换
711-
baseConv := getTypeConversion(baseSrc, baseDst, allowNarrow, singleToSlice, file, path)
722+
baseConv, importPath := getTypeConversion(baseSrc, baseDst, allowNarrow, singleToSlice, file, path)
712723

713724
// 生成指针转换逻辑
714725
return fmt.Sprintf(`func(src %s) %s {
@@ -718,7 +729,7 @@ func handlePointerConversion(srcType, dstType string, allowNarrow, singleToSlice
718729
dst := new(%s)
719730
%s
720731
return dst
721-
}`, srcType, dstType, baseDst, generateElementConversion("*src", "dst", baseConv))
732+
}`, srcType, dstType, baseDst, generateElementConversion("*src", "dst", baseConv)), importPath
722733
}
723734

724735
func Main(dir string) {
@@ -805,9 +816,16 @@ func Main(dir string) {
805816
// 提取字段映射关系
806817
fields := getFieldMappings(srcType, dstType, file, ignoreCase, allowNarrow, singleToSlice, fieldMappings, path)
807818

819+
importPath := []string{}
820+
for _, field := range fields {
821+
if field.ImportPath != "" {
822+
importPath = append(importPath, field.ImportPath)
823+
}
824+
}
808825
// 生成完整的拷贝函数
809826
generateCompleteCopyFunc(funcDecl, srcVar, dstVar, srcType, dstType, fields)
810-
827+
// 加入必要的导入
828+
addRequiredImports(file, importPath...)
811829
// 将修改后的 AST 写回文件
812830
writeFile(fset, file, path)
813831
return true

0 commit comments

Comments
 (0)