Skip to content

Instantly share code, notes, and snippets.

@suntong
Last active April 9, 2025 22:19
Show Gist options
  • Save suntong/56ff467c1479f6d9e892c5e9064f1125 to your computer and use it in GitHub Desktop.
Save suntong/56ff467c1479f6d9e892c5e9064f1125 to your computer and use it in GitHub Desktop.
package main
import (
"encoding/csv"
"fmt"
"io"
"log"
"os"
"regexp"
"strconv"
"strings"
"unicode"
)
// Config holds the user-tunable settings for one CSV-to-MySQL conversion.
type Config struct {
// InputFile is the path of the CSV file that Convert opens.
InputFile string
// TableName is the table named in the generated CREATE TABLE and INSERT
// statements.
TableName string
// Delimiter is the field separator handed to the CSV reader.
Delimiter rune
// HasHeader reports whether the first row holds column names; when false,
// columns are named column_1, column_2, ...
HasHeader bool
// VarcharLength is the VARCHAR width used for string values that are not
// longer than it.
VarcharLength int
// TextThreshold is the value length above which a string column is typed
// TEXT instead of VARCHAR.
TextThreshold int
// BatchInsert selects multi-row INSERT statements instead of one statement
// per row.
BatchInsert bool
// BatchSize is the number of rows collected before a multi-row INSERT is
// emitted.
BatchSize int
// NullString is the field text that is treated as SQL NULL; the comparison
// is case-insensitive. Empty fields are also treated as NULL.
NullString string
// PrimaryKeys lists the column names placed in the PRIMARY KEY clause.
PrimaryKeys []string
// MaxSampleSize is the number of well-formed rows inspected when inferring
// column types.
MaxSampleSize int
}
// CSVToMySQLConverter turns a CSV file into a MySQL CREATE TABLE statement
// plus the INSERT statements for its rows. Config is embedded, so its fields
// are set directly on the converter.
type CSVToMySQLConverter struct {
Config
// ForceTypes overrides type inference. Keys are looked up against the
// sanitized header names (or the generated column_N names).
ForceTypes map[string]string // column name -> MySQL type
// SkipColumns names columns left out of both the CREATE TABLE and the
// INSERT statements. A ForceTypes entry for the same column takes precedence.
SkipColumns map[string]bool // columns to skip
}
// Patterns used by sanitizeColumnName, compiled once at package level.
var (
// sanitizeRegex matches every run of characters other than ASCII letters,
// digits and underscore.
sanitizeRegex = regexp.MustCompile(`[^a-zA-Z0-9_]+`)
// leadingRegex matches a name whose first character is neither an ASCII
// letter nor an underscore (for example a digit).
leadingRegex = regexp.MustCompile(`^[^a-zA-Z_]`)
)
// NewCSVToMySQLConverter returns a converter pre-loaded with defaults:
// comma-delimited input with a header row, VARCHAR(255) strings, TEXT above
// 500 bytes, batched INSERTs of 100 rows, "NULL" as the null marker and a
// 1000-row type-inference sample. InputFile, TableName and PrimaryKeys are
// left at their zero values; the override maps are created empty so callers
// can add entries directly.
func NewCSVToMySQLConverter() *CSVToMySQLConverter {
	converter := &CSVToMySQLConverter{
		ForceTypes:  map[string]string{},
		SkipColumns: map[string]bool{},
	}

	// Parsing defaults.
	converter.Delimiter = ','
	converter.HasHeader = true
	converter.NullString = "NULL"

	// Type-inference defaults.
	converter.VarcharLength = 255
	converter.TextThreshold = 500
	converter.MaxSampleSize = 1000

	// Output defaults.
	converter.BatchInsert = true
	converter.BatchSize = 100

	return converter
}
// Convert reads the CSV file named by c.InputFile and renders it as SQL.
//
// It returns the CREATE TABLE statement, the text of the INSERT statements,
// and an error if the file cannot be opened, rewound or parsed. On error both
// strings are empty.
func (c *CSVToMySQLConverter) Convert() (string, string, error) {
	file, err := os.Open(c.InputFile)
	if err != nil {
		return "", "", fmt.Errorf("error opening file: %w", err)
	}
	defer file.Close()

	// newReader builds a CSV reader that starts at the file's current offset.
	newReader := func() *csv.Reader {
		r := csv.NewReader(file)
		r.Comma = c.Delimiter
		r.TrimLeadingSpace = true
		// With the default of 0 the reader pins the field count to the first
		// record and fails every ragged row with csv.ErrFieldCount, so the
		// "skip rows with the wrong column count" logic downstream could never
		// run. A negative value disables that check.
		r.FieldsPerRecord = -1
		return r
	}

	reader := newReader()
	headers, err := c.readHeaders(reader)
	if err != nil {
		return "", "", fmt.Errorf("error reading headers: %w", err)
	}

	if !c.HasHeader {
		// Without a header row, readHeaders had to consume the first data row
		// to count the columns. Start over so that row is sampled as well.
		if _, err := file.Seek(0, io.SeekStart); err != nil {
			return "", "", fmt.Errorf("error rewinding file: %w", err)
		}
		reader = newReader()
	}

	columnTypes, err := c.determineColumnTypes(reader, headers)
	if err != nil {
		return "", "", fmt.Errorf("error determining column types: %w", err)
	}

	createTable := c.generateCreateTable(headers, columnTypes)

	// generateInsertStatements rewinds the file itself before reading.
	inserts, err := c.generateInsertStatements(file, headers, columnTypes)
	if err != nil {
		return "", "", fmt.Errorf("error generating insert statements: %w", err)
	}
	return createTable, inserts, nil
}
// readHeaders returns the column names for the CSV stream behind reader.
//
// When c.HasHeader is set, the first record is read and each cell is passed
// through sanitizeColumnName; cells that sanitize to the empty string are
// named column_N (1-based). Otherwise the first record is read only to count
// the fields and every column is named column_N.
//
// In both cases exactly one record is consumed from reader. A reader cannot
// be rewound from here, so for header-less input the caller is responsible
// for repositioning the underlying file if that first row must be re-read.
func (c *CSVToMySQLConverter) readHeaders(reader *csv.Reader) ([]string, error) {
	if !c.HasHeader {
		firstRow, err := reader.Read()
		if err != nil {
			return nil, fmt.Errorf("error reading first row: %w", err)
		}
		headers := make([]string, len(firstRow))
		for i := range headers {
			headers[i] = fmt.Sprintf("column_%d", i+1)
		}
		// The previous implementation reopened the input file here and
		// assigned a new reader to the local parameter. That had no effect on
		// the caller's reader, so it only cost a file open and an extra error
		// path; it has been removed.
		return headers, nil
	}

	rawHeaders, err := reader.Read()
	if err != nil {
		return nil, fmt.Errorf("error reading header: %w", err)
	}
	headers := make([]string, len(rawHeaders))
	for i, h := range rawHeaders {
		headers[i] = c.sanitizeColumnName(h)
		if headers[i] == "" {
			headers[i] = fmt.Sprintf("column_%d", i+1)
		}
	}
	return headers, nil
}
// sanitizeColumnName converts a raw header cell into a lower-case identifier.
//
// Surrounding whitespace is dropped, every run of characters outside
// [a-zA-Z0-9_] collapses to a single underscore, leading and trailing
// underscores are trimmed, and a name that would not start with a letter or
// underscore gets an underscore prefix. The result may be empty.
func (c *CSVToMySQLConverter) sanitizeColumnName(name string) string {
	cleaned := sanitizeRegex.ReplaceAllString(strings.TrimSpace(name), "_")
	cleaned = strings.Trim(cleaned, "_")
	if leadingRegex.MatchString(cleaned) {
		cleaned = "_" + cleaned
	}

	// Lower-case the name and map any whitespace rune to an underscore.
	spaceToUnderscore := func(r rune) rune {
		if unicode.IsSpace(r) {
			return '_'
		}
		return r
	}
	return strings.Map(spaceToUnderscore, strings.ToLower(cleaned))
}
// determineColumnTypes picks a MySQL type for every column in headers.
//
// Columns listed in c.ForceTypes get the forced type verbatim; otherwise
// columns flagged in c.SkipColumns get the marker "SKIP". The remaining
// columns are inferred by reading records from reader: rows whose field count
// differs from len(headers) are logged and ignored, empty fields and fields
// equal to c.NullString (case-insensitive) carry no type information, and
// every other value is folded into the column's type with refineType.
// Sampling stops at end of input or once c.MaxSampleSize well-formed rows have
// been processed. A column with no usable value is typed TEXT.
//
// The returned slice is parallel to headers. A read error other than io.EOF
// aborts the scan and is returned wrapped.
func (c *CSVToMySQLConverter) determineColumnTypes(reader *csv.Reader, headers []string) ([]string, error) {
	columnTypes := make([]string, len(headers))
	// inferred marks the columns whose type comes from the data. Their entry
	// in columnTypes stays "" until the first usable value is seen.
	inferred := make([]bool, len(headers))
	for i, header := range headers {
		forcedType, forced := c.ForceTypes[header]
		switch {
		case forced:
			columnTypes[i] = forcedType
		case c.SkipColumns[header]:
			columnTypes[i] = "SKIP"
		default:
			inferred[i] = true
		}
	}

	sampleCount := 0
	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("error reading record: %w", err)
		}
		if len(record) != len(headers) {
			log.Printf("Skipping row with %d columns (expected %d)", len(record), len(headers))
			continue
		}
		for i, field := range record {
			if !inferred[i] {
				continue
			}
			value := strings.TrimSpace(field)
			if value == "" || strings.EqualFold(value, c.NullString) {
				continue
			}
			columnTypes[i] = c.refineType(columnTypes[i], value)
		}
		sampleCount++
		if sampleCount >= c.MaxSampleSize {
			break
		}
	}

	for i := range columnTypes {
		if inferred[i] && columnTypes[i] == "" {
			columnTypes[i] = "TEXT"
		}
	}
	return columnTypes, nil
}

// refineType folds one more observed value into a column's inferred type and
// returns the type that can hold both the earlier values and this one.
//
// currentType is the type inferred so far, or "" when no value has been seen.
// Previously the current type was ignored, so whichever value was sampled last
// decided the column type (for example "abc" followed by "123" produced
// BIGINT). Types are now only ever widened:
//
//	BIGINT + DECIMAL(20,6)  -> DECIMAL(20,6)
//	DATE + DATETIME         -> DATETIME
//	VARCHAR(a) + VARCHAR(b) -> the wider of the two
//	anything + TEXT         -> TEXT
//	any other mix           -> a VARCHAR (or TEXT) sized for the value
func (c *CSVToMySQLConverter) refineType(currentType, value string) string {
	observed := c.inferValueType(value)
	switch {
	case currentType == "" || currentType == observed:
		return observed
	case currentType == "TEXT" || observed == "TEXT":
		return "TEXT"
	case isInferredNumeric(currentType) && isInferredNumeric(observed):
		return "DECIMAL(20,6)"
	case isInferredTemporal(currentType) && isInferredTemporal(observed):
		return "DATETIME"
	}

	// The two types belong to different families, so the column has to be a
	// string column. Size it for this value; lengths of earlier non-string
	// values were not recorded, so they are assumed to fit the default width.
	asString := c.stringColumnType(len(value))
	if asString == "TEXT" {
		return "TEXT"
	}
	if varcharWidth(currentType) > varcharWidth(asString) {
		return currentType
	}
	return asString
}

// inferValueType returns the narrowest type that fits value on its own:
// BIGINT, DECIMAL(20,6), DATE or DATETIME (date-like values longer than 10
// bytes), and otherwise a string type chosen by stringColumnType.
func (c *CSVToMySQLConverter) inferValueType(value string) string {
	switch {
	case isInteger(value):
		return "BIGINT"
	case isDecimal(value):
		return "DECIMAL(20,6)"
	case isDate(value):
		if len(value) > 10 {
			return "DATETIME"
		}
		return "DATE"
	}
	return c.stringColumnType(len(value))
}

// stringColumnType returns the string type for a value of the given byte
// length: TEXT above c.TextThreshold, a VARCHAR rounded up to the next
// multiple of 50 above c.VarcharLength, and VARCHAR(c.VarcharLength) otherwise.
func (c *CSVToMySQLConverter) stringColumnType(length int) string {
	switch {
	case length > c.TextThreshold:
		return "TEXT"
	case length > c.VarcharLength:
		return fmt.Sprintf("VARCHAR(%d)", ((length/50)+1)*50)
	default:
		return fmt.Sprintf("VARCHAR(%d)", c.VarcharLength)
	}
}

// isInferredNumeric reports whether sqlType is one of the numeric types
// produced by type inference.
func isInferredNumeric(sqlType string) bool {
	return sqlType == "BIGINT" || sqlType == "DECIMAL(20,6)"
}

// isInferredTemporal reports whether sqlType is one of the date/time types
// produced by type inference.
func isInferredTemporal(sqlType string) bool {
	return sqlType == "DATE" || sqlType == "DATETIME"
}

// varcharWidth returns n for a type spelled "VARCHAR(n)" and 0 for anything
// else.
func varcharWidth(sqlType string) int {
	var width int
	if _, err := fmt.Sscanf(sqlType, "VARCHAR(%d)", &width); err != nil {
		return 0
	}
	return width
}
// generateCreateTable renders the CREATE TABLE statement for c.TableName.
//
// headers and columnTypes are parallel slices; entries typed "SKIP" are left
// out. When c.PrimaryKeys is non-empty a PRIMARY KEY clause listing those
// names is appended after the column definitions. The table is declared as
// InnoDB with the utf8mb4 / utf8mb4_unicode_ci character set and collation.
func (c *CSVToMySQLConverter) generateCreateTable(headers []string, columnTypes []string) string {
	definitions := make([]string, 0, len(headers))
	for i := range headers {
		sqlType := columnTypes[i]
		if sqlType != "SKIP" {
			definitions = append(definitions, fmt.Sprintf(" `%s` %s", headers[i], sqlType))
		}
	}

	if len(c.PrimaryKeys) > 0 {
		quotedKeys := make([]string, len(c.PrimaryKeys))
		for i, key := range c.PrimaryKeys {
			quotedKeys[i] = fmt.Sprintf("`%s`", key)
		}
		definitions = append(definitions, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(quotedKeys, ", ")))
	}

	var out strings.Builder
	fmt.Fprintf(&out, "CREATE TABLE `%s` (\n", c.TableName)
	out.WriteString(strings.Join(definitions, ",\n"))
	out.WriteString("\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;")
	return out.String()
}
// generateInsertStatements re-reads file from the start and renders its data
// rows as INSERT statements for c.TableName.
//
// headers and columnTypes are parallel slices; columns typed "SKIP" are
// omitted. Rows whose field count differs from len(headers) are ignored. With
// c.BatchInsert the rows are grouped into multi-row statements of c.BatchSize
// rows (the final batch may be shorter); otherwise one statement is emitted
// per row. The result is empty when the file holds no data rows.
//
// An error is returned if the file cannot be rewound or a record cannot be
// parsed.
func (c *CSVToMySQLConverter) generateInsertStatements(file *os.File, headers []string, columnTypes []string) (string, error) {
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return "", fmt.Errorf("error rewinding file: %w", err)
	}

	reader := csv.NewReader(file)
	reader.Comma = c.Delimiter
	// Parse with the same leading-space handling as the type-sampling pass so
	// both passes accept exactly the same rows.
	reader.TrimLeadingSpace = true
	// A negative value disables the reader's own field-count check; without it
	// a ragged row is a hard csv.ErrFieldCount error and the skip below is
	// unreachable.
	reader.FieldsPerRecord = -1

	if c.HasHeader {
		if _, err := reader.Read(); err != nil {
			if err == io.EOF {
				return "", nil
			}
			return "", fmt.Errorf("error reading header: %w", err)
		}
	}

	// The column list is identical for every single-row statement.
	columnList := c.formatInsertColumns(headers, columnTypes)

	var sb strings.Builder
	var batchRows []string
	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return "", fmt.Errorf("error reading record: %w", err)
		}
		if len(record) != len(headers) {
			continue
		}

		values := make([]string, 0, len(headers))
		for i, field := range record {
			if columnTypes[i] == "SKIP" {
				continue
			}
			values = append(values, sqlLiteral(field, columnTypes[i], c.NullString))
		}
		row := strings.Join(values, ", ")

		if !c.BatchInsert {
			fmt.Fprintf(&sb, "INSERT INTO `%s` (%s) VALUES (%s);\n", c.TableName, columnList, row)
			continue
		}
		batchRows = append(batchRows, fmt.Sprintf("(%s)", row))
		if len(batchRows) >= c.BatchSize {
			sb.WriteString(c.formatBatchInsert(headers, columnTypes, batchRows))
			// Safe to reuse the backing array: the batch is already rendered.
			batchRows = batchRows[:0]
		}
	}
	if len(batchRows) > 0 {
		sb.WriteString(c.formatBatchInsert(headers, columnTypes, batchRows))
	}
	return sb.String(), nil
}

// sqlLiteral renders one CSV field as a MySQL literal. Whitespace-only fields
// and fields equal to nullString (case-insensitive) become NULL. A plain
// number destined for a numeric column is emitted bare. Everything else is
// single-quoted with quotes doubled and backslashes escaped.
func sqlLiteral(field, columnType, nullString string) string {
	value := strings.TrimSpace(field)
	if value == "" || strings.EqualFold(value, nullString) {
		return "NULL"
	}
	if isNumericColumnType(columnType) && isPlainNumber(value) {
		return value
	}
	escaped := strings.ReplaceAll(value, "'", "''")
	escaped = strings.ReplaceAll(escaped, "\\", "\\\\")
	return fmt.Sprintf("'%s'", escaped)
}

// isNumericColumnType reports whether columnType names a MySQL numeric type.
// It matches on the leading keyword so that forced types such as
// "INT AUTO_INCREMENT" or "DECIMAL(10,2)" are recognised alongside the
// inferred BIGINT and DECIMAL(20,6). The earlier exact comparison against
// "INT" never matched an inferred type.
func isNumericColumnType(columnType string) bool {
	upper := strings.ToUpper(strings.TrimSpace(columnType))
	for _, keyword := range []string{
		"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT",
		"DECIMAL", "NUMERIC", "FLOAT", "DOUBLE",
	} {
		if strings.HasPrefix(upper, keyword) {
			return true
		}
	}
	return false
}

// isPlainNumber reports whether value is a decimal number that can be written
// into SQL without quotes. strconv.ParseFloat also accepts NaN, Inf/Infinity,
// hexadecimal floats and underscore digit separators; none of those are valid
// bare SQL numbers, so any value containing their marker characters is
// rejected first.
func isPlainNumber(value string) bool {
	if strings.ContainsAny(value, "xXnNiI_") {
		return false
	}
	_, err := strconv.ParseFloat(value, 64)
	return err == nil
}
// formatInsertColumns returns the back-quoted, comma-separated column list
// for an INSERT statement, leaving out every column whose type is "SKIP".
// headers and columnTypes are parallel slices.
func (c *CSVToMySQLConverter) formatInsertColumns(headers []string, columnTypes []string) string {
	quoted := make([]string, 0, len(headers))
	for i := range headers {
		if columnTypes[i] == "SKIP" {
			continue
		}
		quoted = append(quoted, fmt.Sprintf("`%s`", headers[i]))
	}
	return strings.Join(quoted, ", ")
}
// formatBatchInsert renders one multi-row INSERT statement for c.TableName.
// rows holds already-formatted "(v1, v2, ...)" tuples, which are written one
// per line; the column list comes from formatInsertColumns.
func (c *CSVToMySQLConverter) formatBatchInsert(headers []string, columnTypes []string, rows []string) string {
	columnList := c.formatInsertColumns(headers, columnTypes)
	tuples := strings.Join(rows, ",\n")
	return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES\n%s;\n", c.TableName, columnList, tuples)
}
// isInteger reports whether s parses as a base-10 signed 64-bit integer.
func isInteger(s string) bool {
	if _, err := strconv.ParseInt(s, 10, 64); err != nil {
		return false
	}
	return true
}
// isDecimal reports whether s is a plain decimal number: an optional sign,
// digits with an optional fractional part, and an optional decimal exponent.
//
// strconv.ParseFloat on its own also accepts "NaN", "Inf"/"Infinity" (any
// case), hexadecimal floats such as "0x1p-2" and underscore digit separators.
// None of those can be stored in a MySQL numeric column, so a column holding
// them must not be inferred as DECIMAL. They are rejected up front by their
// marker characters, which never occur in a plain decimal number.
func isDecimal(s string) bool {
	if strings.ContainsAny(s, "xXnNiI_") {
		return false
	}
	_, err := strconv.ParseFloat(s, 64)
	return err == nil
}
// datePatterns lists the date and date-time layouts recognised by isDate.
// They are compiled once at package initialisation instead of on every call,
// because isDate runs for each sampled field.
//
// NOTE(review): the slash-separated layout is matched here but does not put
// the year first as MySQL date literals do - confirm how such values are
// meant to be loaded.
var datePatterns = []*regexp.Regexp{
	regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`),
	regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`),
	regexp.MustCompile(`^\d{2}/\d{2}/\d{4}$`),
	regexp.MustCompile(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z?$`),
}

// isDate reports whether s has the shape of one of the layouts in
// datePatterns. Only the digit layout is checked; the field values themselves
// (month 13, day 40, ...) are not validated.
func isDate(s string) bool {
	for _, pattern := range datePatterns {
		if pattern.MatchString(s) {
			return true
		}
	}
	return false
}
// main converts data.csv into SQL for the sales_data table and prints the
// CREATE TABLE statement followed by the INSERT statements to standard
// output. order_id and price get forced column types, order_id is the primary
// key, and the internal_code column is left out. Any conversion error
// terminates the program through log.Fatal.
func main() {
	conv := NewCSVToMySQLConverter()
	conv.InputFile = "data.csv"
	conv.TableName = "sales_data"
	conv.PrimaryKeys = []string{"order_id"}
	conv.ForceTypes = map[string]string{
		"order_id": "INT AUTO_INCREMENT",
		"price":    "DECIMAL(10,2)",
	}
	conv.SkipColumns = map[string]bool{"internal_code": true}

	ddl, dml, err := conv.Convert()
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("-- CREATE TABLE STATEMENT --\n%s\n\n-- INSERT STATEMENTS --\n%s\n", ddl, dml)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment