import com.intellij.database.model.DasTable
import com.intellij.database.util.Case
import com.intellij.database.util.DasUtil

import java.time.LocalDate
import java.time.format.DateTimeFormatter

/*
* Available context bindings:
* SELECTION Iterable<DasObject>
* PROJECT project
* FILES files helper
*/

// Package declaration (including the trailing ";") of the class currently being
// generated; assigned per table in generate(table, dir).
packageName = ""

// Maps a lower-cased column type specification to a Java type name.
// The map literal keeps insertion order and calcFields takes the FIRST pattern that
// matches anywhere in the spec, so specific patterns must precede the broader ones
// they overlap with ("bigint" before "int", "datetime" before "date"/"time").
// The final empty pattern matches everything and acts as the fallback.
typeMapping = [
        // Must come before the generic "int" rule, which would otherwise also match
        // "bigint" and map 64-bit columns to Integer (overflow above 2^31-1).
        (~/(?i)bigint/)                          : "Long",
        (~/(?i)tinyint|smallint|mediumint/)      : "Integer",
        (~/(?i)int/)                             : "Integer",
        (~/(?i)bool|bit/)                        : "Boolean",
        (~/(?i)float|double|decimal|real/)       : "Double",
        (~/(?i)datetime|timestamp/)              : "LocalDateTime",
        (~/(?i)date/)                            : "LocalDate",
        (~/(?i)time/)                            : "LocalTime",
        (~/(?i)blob|binary|bfile|clob|raw|image/): "InputStream",
        (~/(?i)/)                                : "String"
]

// Generation date (yyyy-MM-dd) written into each class's changelog header.
currentDay = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd"))

// Ask for a target directory, then generate one entity class per selected table.
FILES.chooseDirectoryAndSave("Choose directory", "Choose where to store generated files") { dir ->
    SELECTION.filter { it instanceof DasTable }.each { generate(it, dir) }
}

/**
 * Writes <ClassName>.java for a single table into {@code dir}.
 *
 * Side effect: assigns the script-level packageName binding, which the
 * writer overload generate(out, table, className, fields) reads.
 *
 * @param table the selected database table
 * @param dir   directory the generated .java file is written into
 */
def generate(table, dir) {
    def entityName = javaName(table.getName(), true)
    packageName = getPackageName(dir)
    def columns = calcFields(table)

    def target = new File(dir, entityName + ".java")
    target.withPrintWriter("UTF-8") { writer ->
        generate(writer, table, entityName, columns)
    }
}

/**
 * Derives the Java package declaration (with trailing ";") from the directory
 * the user chose.
 *
 * Path separators ("\" and "/") become dots. Everything up to and including the
 * last whole "src" path segment is then stripped, together with an optional
 * ".main.java" and the dot that follows. Examples:
 *   ".../src/main/java/com/foo" -> "com.foo;"
 *   ".../src/com/foo"           -> "com.foo;"
 *   ".../src/main/java"         -> ";" (default package)
 *
 * @param dir the target directory; its toString() is used as the path
 * @return the package name followed by ";"
 */
def getPackageName(dir) {
    def dotted = dir.toString().replaceAll("\\\\", ".").replaceAll("/", ".")
    // "(.*\.)?" anchors on a whole "src" segment (not e.g. "mysrc"). "(\.|$)"
    // consumes the separator after the source root so the result never starts
    // with "." (which would produce "package .com.foo;" or "package .main.java;").
    // Single quotes keep "$" a regex end anchor instead of GString syntax.
    return dotted.replaceAll('^(.*\\.)?src(\\.main\\.java)?(\\.|$)', "") + ";"
}

/**
 * Tells whether any field descriptor has the given Java type name.
 *
 * @param fields   field descriptors (maps with a "type" entry)
 * @param dateType Java type name to look for, e.g. "LocalDate"
 * @return true if at least one field's type equals dateType
 */
def containsDate(fields, dateType) {
    fields.any { field -> field.type == dateType }
}

/**
 * Writes the Java source of one entity class for {@code table} to {@code out}.
 *
 * Annotations added:
 *   - the class gets Lombok @Data, JPA @Entity and MyBatis-Plus @TableName;
 *   - primary-key fields get @TableId;
 *   - LocalDateTime/LocalDate fields get Jackson @JsonFormat and Spring @DateTimeFormat.
 * Reads the script-level packageName and currentDay bindings.
 *
 * @param out       writer the generated source is printed to
 * @param table     table the class is generated from (its name and comment are used)
 * @param className name of the generated class
 * @param fields    field descriptors from calcFields: maps with name, type, primary and
 *                  comment entries; an optional column entry holds the raw column name
 */
def generate(out, table, className, fields) {
    def hasLocalDateTime = containsDate(fields, "LocalDateTime")
    def hasLocalDate = containsDate(fields, "LocalDate")
    def hasLocalTime = containsDate(fields, "LocalTime")
    // Only LocalDateTime and LocalDate fields receive @JsonFormat/@DateTimeFormat
    // below, so a LocalTime-only class does not need those imports.
    def needsDateFormatting = hasLocalDateTime || hasLocalDate
    def hasPrimaryKey = fields.any { it.primary }
    def hasAutoId = fields.any { it.primary && isAutoIncrementId(it) }

    // "package ;" does not compile: omit the clause for the default package.
    if (packageName && packageName != ";") {
        out.println "package $packageName"
        out.println ""
    }

    out.println "import lombok.Data;"
    out.println "import java.io.Serializable;"
    out.println "import javax.persistence.Entity;"
    if (hasAutoId) {
        // Referenced by the IdType.AUTO @TableId emitted for integral keys.
        out.println "import com.baomidou.mybatisplus.annotation.IdType;"
    }
    if (hasPrimaryKey) {
        out.println "import com.baomidou.mybatisplus.annotation.TableId;"
    }
    out.println "import com.baomidou.mybatisplus.annotation.TableName;"

    if (needsDateFormatting) {
        out.println "import com.vivo.util.dateTime.ZeusDateTimeUtils;"
        out.println "import com.fasterxml.jackson.annotation.JsonFormat;"
        out.println "import org.springframework.format.annotation.DateTimeFormat;"
    }

    if (hasLocalDateTime) {
        out.println "import java.time.LocalDateTime;"
    }

    if (hasLocalDate) {
        out.println "import java.time.LocalDate;"
    }

    if (hasLocalTime) {
        out.println "import java.time.LocalTime;"
    }
    out.println ""
    out.println "/**"
    out.println " * Copyright (c) vivo Information Technology, Inc."
    out.println " * All rights reserved."
    out.println " *"
    out.println " * @author Zhiyong Yang"
    // A table without a comment would otherwise print "Description: null".
    out.println " * Description: ${table.getComment() ?: ''}"
    out.println " * Changelog:"
    out.println " * Revision 1.0 ${currentDay} Zhiyong Yang"
    out.println " * - initialization"
    out.println " */"
    out.println "@Data"
    out.println "@Entity"
    out.println "@TableName(\"${table.getName()}\")"
    out.println "public class $className implements Serializable {"
    out.println ""
    out.println genSerialID()
    fields.each { field ->
        out.println "\t/**"
        out.println "\t * ${field.comment ?: ''}"
        out.println "\t */"

        if (field.primary) {
            // @TableId.value is a column name: prefer the raw column name when
            // calcFields supplies one, since the camelCase field name differs for
            // snake_case columns.
            def column = field.column ?: field.name
            if (isAutoIncrementId(field)) {
                out.println "\t@TableId(value = \"${column}\", type = IdType.AUTO)"
            } else {
                out.println "\t@TableId(value = \"${column}\")"
            }
        }

        switch (field.type) {
            case "LocalDateTime":
                out.println "\t@JsonFormat(pattern = ZeusDateTimeUtils.DEFAULT_TIME_PATTERN, timezone = ZeusDateTimeUtils.SHANG_HAI)"
                out.println "\t@DateTimeFormat(pattern = ZeusDateTimeUtils.DEFAULT_TIME_PATTERN)"
                break
            case "LocalDate":
                out.println "\t@JsonFormat(pattern = ZeusDateTimeUtils.DEFAULT_DATE_PATTERN, timezone = ZeusDateTimeUtils.SHANG_HAI)"
                out.println "\t@DateTimeFormat(pattern = ZeusDateTimeUtils.DEFAULT_DATE_PATTERN)"
                break
        }

        out.println "\tprivate ${field.type} ${field.name};"
        out.println ""
    }
    out.println "}"
}

/**
 * True when a field's Java type is Integer or Long, the integral types generated
 * with MyBatis-Plus IdType.AUTO when they form the primary key.
 */
def isAutoIncrementId(field) {
    return field.type?.toString() in ["Integer", "Long"]
}

/**
 * Builds one field descriptor per column of {@code table}.
 *
 * Each descriptor is a map with:
 *   name    - camelCase Java field name (javaName of the column name)
 *   column  - the raw column name, e.g. for @TableId(value = ...)
 *   type    - Java type name taken from the first typeMapping pattern that
 *             matches the lower-cased column type specification
 *   primary - whether DasUtil reports the column as part of the primary key
 *   comment - the column comment as returned by the database (may be null)
 *   annos   - always "" (kept so existing consumers of the map still work)
 *
 * @param table database table whose columns are read
 * @return list of descriptor maps, in column order
 */
def calcFields(table) {
    // Materialize the columns first, then map each one. This replaces the former
    // reduce/+= accumulation, which copied the whole list on every column.
    DasUtil.getColumns(table).toList().collect { col ->
        def spec = Case.LOWER.apply(col.getDataType().getSpecification())
        def typeStr = typeMapping.find { p, t -> p.matcher(spec).find() }.value
        [
                name   : javaName(col.getName(), false),
                column : col.getName(),
                type   : typeStr,
                primary: DasUtil.isPrimary(col),
                comment: col.getComment(),
                annos  : ""
        ]
    }
}

/**
 * Converts a database identifier into a Java identifier.
 *
 * The name is split into words with IntelliJ's NameUtil. Each word is lower-cased
 * and then capitalized, and the words are joined. Any character not valid in a
 * Java identifier becomes "_".
 *
 * @param str        the database name (table or column)
 * @param capitalize true for UpperCamelCase (class names), false for lowerCamelCase
 * @return the converted name; a single-character result is returned unchanged
 */
def javaName(str, capitalize) {
    def words = com.intellij.psi.codeStyle.NameUtil.splitNameIntoWords(str)
    def camel = words.collect { word -> Case.LOWER.apply(word).capitalize() }.join("")
    def sanitized = camel.replaceAll(/[^\p{javaJavaIdentifierPart}[_]]/, "_")

    if (capitalize || sanitized.length() == 1) {
        return sanitized
    }
    // lowerCamelCase: lower-case only the leading character.
    return Case.LOWER.apply(sanitized[0]) + sanitized[1..-1]
}

/**
 * Produces a serialVersionUID field declaration line (tab-indented) with a
 * randomly chosen value, e.g. "\tprivate static final long serialVersionUID = 123L;".
 * A new value is drawn on every call.
 */
static String genSerialID() {
    long uid = Math.abs(new Random().nextLong())
    return "\tprivate static final long serialVersionUID = ${uid}L;"
}

 

// Posted on 2025-06-11 10:08 by java先生 · Views: 8 · Comments: 0 · Favorite · Report