You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

180 lines
5.2 KiB

3 years ago
  1. package utils
  2. import (
  3. "errors"
  4. "fmt"
  5. "go/ast"
  6. "go/parser"
  7. "go/token"
  8. "io/ioutil"
  9. "strings"
  10. )
  //@author: [LeonardWang](https://github.com/WangLeonard)
//@description: Marker comments delimiting the region that AutoInjectionCode writes into
//
// startComment and endComment are matched as substrings of comment text.
// AutoInjectionCode inserts new code above the comment containing endComment
// and, when a target function is named, checks the code between the two
// markers for an existing copy before inserting.
const (
	startComment = "Code generated by github.com/flipped-aurora/yibu/server Begin; DO NOT EDIT."
	endComment   = "Code generated by github.com/flipped-aurora/yibu/server End; DO NOT EDIT."
)
  20. //@author: [LeonardWang](https://github.com/WangLeonard)
  21. //@function: AutoInjectionCode
  22. //@description: 向文件中固定注释位置写入代码
  23. //@param: filepath string, funcName string, codeData string
  24. //@return: error
  25. func AutoInjectionCode(filepath string, funcName string, codeData string) error {
  26. srcData, err := ioutil.ReadFile(filepath)
  27. if err != nil {
  28. return err
  29. }
  30. srcDataLen := len(srcData)
  31. fset := token.NewFileSet()
  32. fparser, err := parser.ParseFile(fset, filepath, srcData, parser.ParseComments)
  33. if err != nil {
  34. return err
  35. }
  36. codeData = strings.TrimSpace(codeData)
  37. codeStartPos := -1
  38. codeEndPos := srcDataLen
  39. var expectedFunction *ast.FuncDecl
  40. startCommentPos := -1
  41. endCommentPos := srcDataLen
  42. // 如果指定了函数名,先寻找对应函数
  43. if funcName != "" {
  44. for _, decl := range fparser.Decls {
  45. if funDecl, ok := decl.(*ast.FuncDecl); ok && funDecl.Name.Name == funcName {
  46. expectedFunction = funDecl
  47. codeStartPos = int(funDecl.Body.Lbrace)
  48. codeEndPos = int(funDecl.Body.Rbrace)
  49. break
  50. }
  51. }
  52. }
  53. // 遍历所有注释
  54. for _, comment := range fparser.Comments {
  55. if int(comment.Pos()) > codeStartPos && int(comment.End()) <= codeEndPos {
  56. if startComment != "" && strings.Contains(comment.Text(), startComment) {
  57. startCommentPos = int(comment.Pos()) // Note: Pos is the second '/'
  58. }
  59. if endComment != "" && strings.Contains(comment.Text(), endComment) {
  60. endCommentPos = int(comment.Pos()) // Note: Pos is the second '/'
  61. }
  62. }
  63. }
  64. if endCommentPos == srcDataLen {
  65. return fmt.Errorf("comment:%s not found", endComment)
  66. }
  67. // 在指定函数名,且函数中startComment和endComment都存在时,进行区间查重
  68. if (codeStartPos != -1 && codeEndPos <= srcDataLen) && (startCommentPos != -1 && endCommentPos != srcDataLen) && expectedFunction != nil {
  69. if exist := checkExist(&srcData, startCommentPos, endCommentPos, expectedFunction.Body, codeData); exist {
  70. fmt.Printf("文件 %s 待插入数据 %s 已存在\n", filepath, codeData)
  71. return nil // 这里不需要返回错误?
  72. }
  73. }
  74. // 两行注释中间没有换行时,会被认为是一条Comment
  75. if startCommentPos == endCommentPos {
  76. endCommentPos = startCommentPos + strings.Index(string(srcData[startCommentPos:]), endComment)
  77. for srcData[endCommentPos] != '/' {
  78. endCommentPos--
  79. }
  80. }
  81. // 记录"//"之前的空字符,保持写入后的格式一致
  82. tmpSpace := make([]byte, 0, 10)
  83. for tmp := endCommentPos - 2; tmp >= 0; tmp-- {
  84. if srcData[tmp] != '\n' {
  85. tmpSpace = append(tmpSpace, srcData[tmp])
  86. } else {
  87. break
  88. }
  89. }
  90. reverseSpace := make([]byte, 0, len(tmpSpace))
  91. for index := len(tmpSpace) - 1; index >= 0; index-- {
  92. reverseSpace = append(reverseSpace, tmpSpace[index])
  93. }
  94. // 插入数据
  95. indexPos := endCommentPos - 1
  96. insertData := []byte(append([]byte(codeData+"\n"), reverseSpace...))
  97. remainData := append([]byte{}, srcData[indexPos:]...)
  98. srcData = append(append(srcData[:indexPos], insertData...), remainData...)
  99. // 写回数据
  100. return ioutil.WriteFile(filepath, srcData, 0o600)
  101. }
  // checkExist reports whether target already appears between two marker
// comments inside blockStmt. startPos and endPos are the token.Pos values of
// the markers; only nodes that begin after startPos and end before endPos are
// considered.
//
// For each statement in blockStmt it inspects:
//   - an expression statement that is a call: the call's source text;
//   - a nested bare block: searched recursively;
//   - an assignment whose first right-hand value is a call: each call argument.
//
// Other statement kinds (if, for, switch, ...) are not inspected.
//
// The compared text is srcData[node.Pos()-1 : node.End()] with surrounding
// whitespace trimmed. Assuming the file was parsed alone into a fresh FileSet
// (Pos = byte offset + 1), that span runs one byte past the node, so e.g. a
// trailing ',' after an argument is part of the compared text.
func checkExist(srcData *[]byte, startPos int, endPos int, blockStmt *ast.BlockStmt, target string) bool {
	src := *srcData
	matches := func(node ast.Node) bool {
		from, to := int(node.Pos()), int(node.End())
		if from <= startPos || to >= endPos {
			return false
		}
		return strings.TrimSpace(string(src[from-1:to])) == target
	}

	for _, item := range blockStmt.List {
		switch stmt := item.(type) {
		case *ast.BlockStmt:
			if checkExist(srcData, startPos, endPos, stmt, target) {
				return true
			}
		case *ast.ExprStmt:
			if call, isCall := stmt.X.(*ast.CallExpr); isCall && matches(call) {
				return true
			}
		case *ast.AssignStmt:
			// Covers model-registration style code, e.g. err = db.AutoMigrate(a{}, b{}).
			if len(stmt.Rhs) == 0 {
				continue
			}
			call, isCall := stmt.Rhs[0].(*ast.CallExpr)
			if !isCall {
				continue
			}
			for _, arg := range call.Args {
				if matches(arg) {
					return true
				}
			}
		}
	}
	return false
}
  137. func AutoClearCode(filepath string, codeData string) error {
  138. srcData, err := ioutil.ReadFile(filepath)
  139. if err != nil {
  140. return err
  141. }
  142. srcData, err = cleanCode(codeData, string(srcData))
  143. if err != nil {
  144. return err
  145. }
  146. return ioutil.WriteFile(filepath, srcData, 0o600)
  147. }
  //@function: cleanCode
//@description: Removes the first line of srcData whose whitespace-trimmed content equals clearCode
//@param: clearCode string, srcData string
//@return: []byte, error
//
// The matched line is removed together with its trailing '\n'. A final line
// without a trailing newline is also checked and, if it matches, removed on
// its own. If no line matches, srcData is returned unchanged along with an
// error.
//
// All offsets are byte offsets, so lines containing multi-byte UTF-8 text
// (e.g. Chinese comments, or a U+3000 space removed by TrimSpace) are cut
// exactly at line boundaries.
func cleanCode(clearCode string, srcData string) ([]byte, error) {
	for lineStart := 0; lineStart < len(srcData); {
		lineEnd := len(srcData) // exclusive end of this line's content
		next := len(srcData)    // start of the following line
		if i := strings.IndexByte(srcData[lineStart:], '\n'); i >= 0 {
			lineEnd = lineStart + i
			next = lineEnd + 1
		}
		if strings.TrimSpace(srcData[lineStart:lineEnd]) == clearCode {
			return append([]byte(srcData[:lineStart]), srcData[next:]...), nil
		}
		lineStart = next
	}
	return []byte(srcData), errors.New("未找到内容")
}