Last active
April 9, 2025 22:19
-
-
Save suntong/56ff467c1479f6d9e892c5e9064f1125 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/csv" | |
"fmt" | |
"io" | |
"log" | |
"os" | |
"regexp" | |
"strconv" | |
"strings" | |
"unicode" | |
) | |
// Config holds the user-tunable settings of a CSV-to-MySQL conversion.
type Config struct {
	// InputFile is the path of the CSV file to read; TableName is the MySQL
	// table named in the generated CREATE TABLE and INSERT statements.
	InputFile, TableName string

	// Delimiter is the field separator handed to encoding/csv.
	Delimiter rune
	// HasHeader reports whether the first record holds column names; when it
	// is false, columns are named column_1, column_2, ...
	HasHeader bool

	// VarcharLength is the default VARCHAR size for string columns, and
	// TextThreshold is the value length (in bytes) above which a string
	// column is typed TEXT instead.
	VarcharLength, TextThreshold int

	// BatchInsert groups rows into multi-row INSERT statements of up to
	// BatchSize rows each; otherwise one INSERT is emitted per row.
	BatchInsert bool
	BatchSize   int

	// NullString is the cell value (compared case-insensitively) that, like
	// an empty cell, is written as SQL NULL.
	NullString string
	// PrimaryKeys lists the column names emitted in the PRIMARY KEY clause.
	PrimaryKeys []string
	// MaxSampleSize caps how many data rows are examined when inferring
	// column types.
	MaxSampleSize int
}
type CSVToMySQLConverter struct { | |
Config | |
ForceTypes map[string]string // column name -> MySQL type | |
SkipColumns map[string]bool // columns to skip | |
} | |
// sanitizeRegex matches each run of characters outside [a-zA-Z0-9_]; such
// runs are collapsed to a single underscore when sanitizing column names.
var sanitizeRegex = regexp.MustCompile(`[^a-zA-Z0-9_]+`)

// leadingRegex matches a string whose first character is not an ASCII letter
// or an underscore (for example, a leading digit).
var leadingRegex = regexp.MustCompile(`^[^a-zA-Z_]`)
func NewCSVToMySQLConverter() *CSVToMySQLConverter { | |
return &CSVToMySQLConverter{ | |
Config: Config{ | |
Delimiter: ',', | |
HasHeader: true, | |
VarcharLength: 255, | |
TextThreshold: 500, | |
BatchInsert: true, | |
BatchSize: 100, | |
NullString: "NULL", | |
MaxSampleSize: 1000, | |
}, | |
ForceTypes: make(map[string]string), | |
SkipColumns: make(map[string]bool), | |
} | |
} | |
func (c *CSVToMySQLConverter) Convert() (string, string, error) { | |
file, err := os.Open(c.InputFile) | |
if err != nil { | |
return "", "", fmt.Errorf("error opening file: %w", err) | |
} | |
defer file.Close() | |
reader := csv.NewReader(file) | |
reader.Comma = c.Delimiter | |
reader.TrimLeadingSpace = true | |
headers, err := c.readHeaders(reader) | |
if err != nil { | |
return "", "", fmt.Errorf("error reading headers: %w", err) | |
} | |
columnTypes, err := c.determineColumnTypes(reader, headers) | |
if err != nil { | |
return "", "", fmt.Errorf("error determining column types: %w", err) | |
} | |
createTable := c.generateCreateTable(headers, columnTypes) | |
inserts, err := c.generateInsertStatements(file, headers, columnTypes) | |
if err != nil { | |
return "", "", fmt.Errorf("error generating insert statements: %w", err) | |
} | |
return createTable, inserts, nil | |
} | |
func (c *CSVToMySQLConverter) readHeaders(reader *csv.Reader) ([]string, error) { | |
if c.HasHeader { | |
rawHeaders, err := reader.Read() | |
if err != nil { | |
return nil, fmt.Errorf("error reading header: %w", err) | |
} | |
headers := make([]string, len(rawHeaders)) | |
for i, h := range rawHeaders { | |
headers[i] = c.sanitizeColumnName(h) | |
if headers[i] == "" { | |
headers[i] = fmt.Sprintf("column_%d", i+1) | |
} | |
} | |
return headers, nil | |
} | |
firstRow, err := reader.Read() | |
if err != nil { | |
return nil, fmt.Errorf("error reading first row: %w", err) | |
} | |
headers := make([]string, len(firstRow)) | |
for i := range firstRow { | |
headers[i] = fmt.Sprintf("column_%d", i+1) | |
} | |
file, err := os.Open(c.InputFile) | |
if err != nil { | |
return nil, fmt.Errorf("error reopening file: %w", err) | |
} | |
defer file.Close() | |
reader = csv.NewReader(file) | |
reader.Comma = c.Delimiter | |
reader.TrimLeadingSpace = true | |
return headers, nil | |
} | |
func (c *CSVToMySQLConverter) sanitizeColumnName(name string) string { | |
name = sanitizeRegex.ReplaceAllString(strings.TrimSpace(name), "_") | |
name = strings.Trim(name, "_") | |
if leadingRegex.MatchString(name) { | |
name = "_" + name | |
} | |
// Convert to lowercase and replace spaces | |
var sb strings.Builder | |
for _, r := range strings.ToLower(name) { | |
if unicode.IsSpace(r) { | |
sb.WriteRune('_') | |
} else { | |
sb.WriteRune(r) | |
} | |
} | |
return sb.String() | |
} | |
func (c *CSVToMySQLConverter) determineColumnTypes(reader *csv.Reader, headers []string) ([]string, error) { | |
columnTypes := make([]string, len(headers)) | |
for i := range headers { | |
if forcedType, ok := c.ForceTypes[headers[i]]; ok { | |
columnTypes[i] = forcedType | |
} else if c.SkipColumns[headers[i]] { | |
columnTypes[i] = "SKIP" | |
} else { | |
columnTypes[i] = "TEXT" | |
} | |
} | |
sampleCount := 0 | |
for { | |
record, err := reader.Read() | |
if err == io.EOF { | |
break | |
} | |
if err != nil { | |
return nil, fmt.Errorf("error reading record: %w", err) | |
} | |
if len(record) != len(headers) { | |
log.Printf("Skipping row with %d columns (expected %d)", len(record), len(headers)) | |
continue | |
} | |
for i, value := range record { | |
if columnTypes[i] == "SKIP" { | |
continue | |
} | |
value = strings.TrimSpace(value) | |
if value == "" || strings.EqualFold(value, c.NullString) { | |
continue | |
} | |
if _, ok := c.ForceTypes[headers[i]]; !ok { | |
columnTypes[i] = c.refineType(columnTypes[i], value) | |
} | |
} | |
sampleCount++ | |
if sampleCount >= c.MaxSampleSize { | |
break | |
} | |
} | |
return columnTypes, nil | |
} | |
func (c *CSVToMySQLConverter) refineType(currentType, value string) string { | |
if isInteger(value) { | |
return "BIGINT" | |
} | |
if isDecimal(value) { | |
return "DECIMAL(20,6)" | |
} | |
if isDate(value) { | |
if len(value) > 10 { | |
return "DATETIME" | |
} | |
return "DATE" | |
} | |
length := len(value) | |
if length > c.TextThreshold { | |
return "TEXT" | |
} | |
if length > c.VarcharLength { | |
return fmt.Sprintf("VARCHAR(%d)", ((length/50)+1)*50) | |
} | |
return fmt.Sprintf("VARCHAR(%d)", c.VarcharLength) | |
} | |
func (c *CSVToMySQLConverter) generateCreateTable(headers []string, columnTypes []string) string { | |
var sb strings.Builder | |
sb.WriteString(fmt.Sprintf("CREATE TABLE `%s` (\n", c.TableName)) | |
columns := make([]string, 0, len(headers)) | |
for i, header := range headers { | |
if columnTypes[i] == "SKIP" { | |
continue | |
} | |
columns = append(columns, fmt.Sprintf(" `%s` %s", header, columnTypes[i])) | |
} | |
if len(c.PrimaryKeys) > 0 { | |
pkColumns := make([]string, 0, len(c.PrimaryKeys)) | |
for _, pk := range c.PrimaryKeys { | |
pkColumns = append(pkColumns, fmt.Sprintf("`%s`", pk)) | |
} | |
columns = append(columns, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(pkColumns, ", "))) | |
} | |
sb.WriteString(strings.Join(columns, ",\n")) | |
sb.WriteString("\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;") | |
return sb.String() | |
} | |
func (c *CSVToMySQLConverter) generateInsertStatements(file *os.File, headers []string, columnTypes []string) (string, error) { | |
file.Seek(0, 0) | |
reader := csv.NewReader(file) | |
reader.Comma = c.Delimiter | |
if c.HasHeader { | |
reader.Read() | |
} | |
var sb strings.Builder | |
var batchRows []string | |
rowCount := 0 | |
for { | |
record, err := reader.Read() | |
if err == io.EOF { | |
break | |
} | |
if err != nil { | |
return "", fmt.Errorf("error reading record: %w", err) | |
} | |
if len(record) != len(headers) { | |
continue | |
} | |
values := make([]string, 0, len(headers)) | |
for i, value := range record { | |
if columnTypes[i] == "SKIP" { | |
continue | |
} | |
value = strings.TrimSpace(value) | |
if value == "" || strings.EqualFold(value, c.NullString) { | |
values = append(values, "NULL") | |
continue | |
} | |
escaped := strings.ReplaceAll(value, "'", "''") | |
escaped = strings.ReplaceAll(escaped, "\\", "\\\\") | |
if columnTypes[i] == "INT" || columnTypes[i] == "DECIMAL(20,6)" { | |
if _, err := strconv.ParseFloat(value, 64); err == nil { | |
values = append(values, escaped) | |
continue | |
} | |
} | |
values = append(values, fmt.Sprintf("'%s'", escaped)) | |
} | |
if c.BatchInsert { | |
batchRows = append(batchRows, fmt.Sprintf("(%s)", strings.Join(values, ", "))) | |
if len(batchRows) >= c.BatchSize { | |
sb.WriteString(c.formatBatchInsert(headers, columnTypes, batchRows)) | |
batchRows = batchRows[:0] | |
} | |
} else { | |
sb.WriteString(fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s);\n", | |
c.TableName, | |
c.formatInsertColumns(headers, columnTypes), | |
strings.Join(values, ", "))) | |
} | |
rowCount++ | |
} | |
if len(batchRows) > 0 { | |
sb.WriteString(c.formatBatchInsert(headers, columnTypes, batchRows)) | |
} | |
return sb.String(), nil | |
} | |
func (c *CSVToMySQLConverter) formatInsertColumns(headers []string, columnTypes []string) string { | |
var cols []string | |
for i, h := range headers { | |
if columnTypes[i] != "SKIP" { | |
cols = append(cols, fmt.Sprintf("`%s`", h)) | |
} | |
} | |
return strings.Join(cols, ", ") | |
} | |
func (c *CSVToMySQLConverter) formatBatchInsert(headers []string, columnTypes []string, rows []string) string { | |
return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES\n%s;\n", | |
c.TableName, | |
c.formatInsertColumns(headers, columnTypes), | |
strings.Join(rows, ",\n")) | |
} | |
// isInteger reports whether s parses as a base-10 signed 64-bit integer. An
// optional leading '+' or '-' is accepted; surrounding spaces are not.
func isInteger(s string) bool {
	if _, err := strconv.ParseInt(s, 10, 64); err != nil {
		return false
	}
	return true
}
// plainDecimalRegex matches ordinary decimal notation with an optional sign,
// fraction and exponent, e.g. "42", "-3.5", ".25", "1e-3".
var plainDecimalRegex = regexp.MustCompile(`^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$`)

// isDecimal reports whether s is a finite number written in plain decimal
// notation.
//
// strconv.ParseFloat alone is too permissive here. It also accepts "NaN",
// "Inf"/"Infinity" and hexadecimal floats such as "0x1p4", none of which MySQL
// accepts as a numeric literal. Such values must not be classified as
// DECIMAL. ParseFloat is still called to reject values outside float64 range
// (e.g. "1e400"), as before.
func isDecimal(s string) bool {
	if !plainDecimalRegex.MatchString(s) {
		return false
	}
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}
// Layouts that MySQL accepts verbatim as DATE / DATETIME string literals. Only
// these are recognized, because generateInsertStatements writes temporal
// values exactly as they appear in the CSV. Formats such as "04/09/2025" or a
// trailing "Z" would otherwise yield a DATE/DATETIME column whose values MySQL
// cannot parse. The patterns are compiled once rather than on every call.
var (
	mysqlDatePattern     = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`)
	mysqlDateTimePattern = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}[ T]([01]\d|2[0-3]):[0-5]\d:[0-5]\d$`)
)

// isDate reports whether s is a valid calendar date written as "YYYY-MM-DD",
// optionally followed by a space or 'T' and a 24-hour "hh:mm:ss" time.
func isDate(s string) bool {
	if !mysqlDatePattern.MatchString(s) && !mysqlDateTimePattern.MatchString(s) {
		return false
	}
	return isValidCalendarDate(s[:10])
}

// isValidCalendarDate reports whether the "YYYY-MM-DD" string ymd names a real
// day in the Gregorian calendar, leap years included.
func isValidCalendarDate(ymd string) bool {
	// The patterns above guarantee ASCII digits at these offsets, so the
	// conversions cannot fail and their errors are deliberately ignored.
	year, _ := strconv.Atoi(ymd[0:4])
	month, _ := strconv.Atoi(ymd[5:7])
	day, _ := strconv.Atoi(ymd[8:10])
	if month < 1 || month > 12 || day < 1 {
		return false
	}
	daysInMonth := [12]int{31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}[month-1]
	if month == 2 && year%4 == 0 && (year%100 != 0 || year%400 == 0) {
		daysInMonth = 29
	}
	return day <= daysInMonth
}
func main() { | |
converter := NewCSVToMySQLConverter() | |
converter.InputFile = "data.csv" | |
converter.TableName = "sales_data" | |
converter.PrimaryKeys = []string{"order_id"} | |
converter.ForceTypes = map[string]string{ | |
"order_id": "INT AUTO_INCREMENT", | |
"price": "DECIMAL(10,2)", | |
} | |
converter.SkipColumns = map[string]bool{"internal_code": true} | |
createStmt, insertStmts, err := converter.Convert() | |
if err != nil { | |
log.Fatal(err) | |
} | |
fmt.Println("-- CREATE TABLE STATEMENT --") | |
fmt.Println(createStmt) | |
fmt.Println("\n-- INSERT STATEMENTS --") | |
fmt.Println(insertStmts) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment