You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

143 lines
4.3 KiB

  1. package utils
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/parser"
  6. "go/token"
  7. "io/ioutil"
  8. "strings"
  9. )
  10. //@author: [LeonardWang](https://github.com/WangLeonard)
  11. //@function: AutoInjectionCode
  12. //@description: 向文件中固定注释位置写入代码
  13. //@param: filepath string, funcName string, codeData string
  14. //@return: error
  15. func AutoInjectionCode(filepath string, funcName string, codeData string) error {
  16. startComment := "Code generated by gin-vue-admin Begin; DO NOT EDIT."
  17. endComment := "Code generated by gin-vue-admin End; DO NOT EDIT."
  18. srcData, err := ioutil.ReadFile(filepath)
  19. if err != nil {
  20. return err
  21. }
  22. srcDataLen := len(srcData)
  23. fset := token.NewFileSet()
  24. fparser, err := parser.ParseFile(fset, filepath, srcData, parser.ParseComments)
  25. if err != nil {
  26. return err
  27. }
  28. codeData = strings.TrimSpace(codeData)
  29. var codeStartPos = -1
  30. var codeEndPos = srcDataLen
  31. var expectedFunction *ast.FuncDecl
  32. var startCommentPos = -1
  33. var endCommentPos = srcDataLen
  34. // 如果指定了函数名,先寻找对应函数
  35. if funcName != "" {
  36. for _, decl := range fparser.Decls {
  37. if funDecl, ok := decl.(*ast.FuncDecl); ok && funDecl.Name.Name == funcName {
  38. expectedFunction = funDecl
  39. codeStartPos = int(funDecl.Body.Lbrace)
  40. codeEndPos = int(funDecl.Body.Rbrace)
  41. break
  42. }
  43. }
  44. }
  45. // 遍历所有注释
  46. for _, comment := range fparser.Comments {
  47. if int(comment.Pos()) > codeStartPos && int(comment.End()) <= codeEndPos {
  48. if startComment != "" && strings.Contains(comment.Text(), startComment) {
  49. startCommentPos = int(comment.Pos()) // Note: Pos is the second '/'
  50. }
  51. if endComment != "" && strings.Contains(comment.Text(), endComment) {
  52. endCommentPos = int(comment.Pos()) // Note: Pos is the second '/'
  53. }
  54. }
  55. }
  56. if endCommentPos == srcDataLen {
  57. return fmt.Errorf("comment:%s not found", endComment)
  58. }
  59. // 在指定函数名,且函数中startComment和endComment都存在时,进行区间查重
  60. if (codeStartPos != -1 && codeEndPos <= srcDataLen) && (startCommentPos != -1 && endCommentPos != srcDataLen) && expectedFunction != nil {
  61. if exist := checkExist(&srcData, startCommentPos, endCommentPos, expectedFunction.Body, codeData); exist {
  62. fmt.Printf("文件 %s 待插入数据 %s 已存在\n", filepath, codeData)
  63. return nil // 这里不需要返回错误?
  64. }
  65. }
  66. // 两行注释中间没有换行时,会被认为是一条Comment
  67. if startCommentPos == endCommentPos {
  68. endCommentPos = startCommentPos + strings.Index(string(srcData[startCommentPos:]), endComment)
  69. for srcData[endCommentPos] != '/' {
  70. endCommentPos--
  71. }
  72. }
  73. // 记录"//"之前的空字符,保持写入后的格式一致
  74. tmpSpace := make([]byte, 0, 10)
  75. for tmp := endCommentPos - 2; tmp >= 0; tmp-- {
  76. if srcData[tmp] != '\n' {
  77. tmpSpace = append(tmpSpace, srcData[tmp])
  78. } else {
  79. break
  80. }
  81. }
  82. reverseSpace := make([]byte, 0, len(tmpSpace))
  83. for index := len(tmpSpace) - 1; index >= 0; index-- {
  84. reverseSpace = append(reverseSpace, tmpSpace[index])
  85. }
  86. // 插入数据
  87. indexPos := endCommentPos - 1
  88. insertData := []byte(append([]byte(codeData+"\n"), reverseSpace...))
  89. remainData := append([]byte{}, srcData[indexPos:]...)
  90. srcData = append(append(srcData[:indexPos], insertData...), remainData...)
  91. // 写回数据
  92. return ioutil.WriteFile(filepath, srcData, 0600)
  93. }
  94. func checkExist(srcData *[]byte, startPos int, endPos int, blockStmt *ast.BlockStmt, target string) bool {
  95. for _, list := range blockStmt.List {
  96. switch stmt := list.(type) {
  97. case *ast.ExprStmt:
  98. if callExpr, ok := stmt.X.(*ast.CallExpr); ok &&
  99. int(callExpr.Pos()) > startPos && int(callExpr.End()) < endPos {
  100. text := string((*srcData)[int(callExpr.Pos()-1):int(callExpr.End())])
  101. key := strings.TrimSpace(text)
  102. if key == target {
  103. return true
  104. }
  105. }
  106. case *ast.BlockStmt:
  107. if checkExist(srcData, startPos, endPos, stmt, target) {
  108. return true
  109. }
  110. case *ast.AssignStmt:
  111. // 为 model 中的代码进行检查
  112. if len(stmt.Rhs) > 0 {
  113. if callExpr, ok := stmt.Rhs[0].(*ast.CallExpr); ok {
  114. for _, arg := range callExpr.Args {
  115. if int(arg.Pos()) > startPos && int(arg.End()) < endPos {
  116. text := string((*srcData)[int(arg.Pos()-1):int(arg.End())])
  117. key := strings.TrimSpace(text)
  118. if key == target {
  119. return true
  120. }
  121. }
  122. }
  123. }
  124. }
  125. }
  126. }
  127. return false
  128. }