|
| 1 | +package filedetect |
| 2 | + |
| 3 | +import ( |
| 4 | + "os" |
| 5 | + "strings" |
| 6 | + |
| 7 | + "github.com/h2non/filetype" |
| 8 | +) |
| 9 | + |
// FileFormatResult is the outcome of sniffing a file's content type.
type FileFormatResult struct {
	// IsCorrect reports whether the extension in the file's path matches the
	// format detected from its content. It is false when the content could
	// not be read or its format was not recognized.
	IsCorrect bool
	// Extension is the detected extension including the leading dot
	// (for example ".svg"), or "" when nothing was detected.
	Extension string
}
| 15 | + |
// readFileByLen is an optional hook used instead of direct disk access when
// reading the leading bytes of a file. It stays nil until RegisterIoReader
// installs a function.
var readFileByLen func(file string, len int) ([]byte, error)

// RegisterIoReader installs onReadFile as the function getBuffer uses to read
// the leading bytes of a file. Passing nil restores the default reader based
// on os.Open.
func RegisterIoReader(onReadFile func(file string, len int) ([]byte, error)) {
	readFileByLen = onReadFile
}

// getBuffer returns up to size leading bytes of the file at filepath.
//
// When a reader has been registered through RegisterIoReader, the call is
// delegated to it and its result is returned unchanged. Otherwise the file is
// opened with os.Open and read until size bytes have been collected or the
// read stops (EOF or error). A file that yields no bytes at all is reported as
// an error; a file shorter than size yields a shorter slice and a nil error.
func getBuffer(filepath string, size int) ([]byte, error) {
	if readFileByLen != nil {
		// Use the registered reader function.
		data, err := readFileByLen(filepath, size)
		if err != nil {
			// The println builtin prints an interface value as raw pointer
			// words, so the message has to be extracted explicitly.
			println(filepath, "header load failed1", err.Error())
		}
		return data, err
	}

	// Fall back to standard file reading.
	file, err := os.Open(filepath)
	if err != nil {
		println(filepath, "header load failed2", err.Error())
		return nil, err
	}
	defer file.Close()

	data := make([]byte, size)
	filled := 0
	// A single Read call may return fewer bytes than requested even though
	// more are available, so keep reading until the buffer is full.
	for filled < size {
		n, readErr := file.Read(data[filled:])
		filled += n
		if readErr != nil {
			if filled == 0 {
				println(filepath, "header load failed3", readErr.Error())
				return nil, readErr
			}
			// Some bytes were read before the stream ended; the partial
			// header is still usable for detection.
			break
		}
	}
	return data[:filled], nil
}
| 54 | + |
| 55 | +// GetFileFormat gets the file format and its correctness |
| 56 | +func GetFileFormat(filepath string) *FileFormatResult { |
| 57 | + // Extract file extension from filepath |
| 58 | + fileExt := "" |
| 59 | + if lastDot := strings.LastIndex(filepath, "."); lastDot != -1 { |
| 60 | + fileExt = strings.ToLower(filepath[lastDot:]) |
| 61 | + } |
| 62 | + |
| 63 | + // Check svg |
| 64 | + buffer, err := getBuffer(filepath, 64) |
| 65 | + if err != nil { |
| 66 | + println(filepath, "header load failed", err) |
| 67 | + return &FileFormatResult{IsCorrect: false, Extension: ""} |
| 68 | + } |
| 69 | + |
| 70 | + if len(buffer) >= 5 { |
| 71 | + headerStr := strings.ToLower(string(buffer[:min(len(buffer), 64)])) |
| 72 | + if strings.Contains(headerStr, "<svg") || (strings.HasPrefix(headerStr, "<?xml") && strings.Contains(headerStr, "svg")) { |
| 73 | + isCorrect := strings.EqualFold(fileExt, ".svg") |
| 74 | + return &FileFormatResult{IsCorrect: isCorrect, Extension: ".svg"} |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + buffer, err = getBuffer(filepath, 262) |
| 79 | + if err != nil { |
| 80 | + println(filepath, "header load failed", err) |
| 81 | + return &FileFormatResult{IsCorrect: false, Extension: ""} |
| 82 | + } |
| 83 | + // Detect file type using filetype library |
| 84 | + kind, err := filetype.Match(buffer) |
| 85 | + if err != nil || kind == filetype.Unknown { |
| 86 | + return &FileFormatResult{IsCorrect: false, Extension: ""} |
| 87 | + } |
| 88 | + |
| 89 | + // Get detected extension with dot prefix |
| 90 | + detectedExt := "." + kind.Extension |
| 91 | + |
| 92 | + // IsCorrect means: file extension matches detected format |
| 93 | + isCorrect := strings.EqualFold(fileExt, detectedExt) |
| 94 | + |
| 95 | + return &FileFormatResult{IsCorrect: isCorrect, Extension: detectedExt} |
| 96 | +} |
0 commit comments