diff --git a/drivers/teldrive/types.go b/drivers/teldrive/types.go index 084f967e6..d75347b89 100644 --- a/drivers/teldrive/types.go +++ b/drivers/teldrive/types.go @@ -52,7 +52,7 @@ type chunkTask struct { fileName string chunkSize int64 reader io.ReadSeeker - ss stream.StreamSectionReaderIF + ss stream.StreamSectionReader } type CopyManager struct { diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 74e218f8f..b1297228b 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -99,24 +99,41 @@ func InitConfig() { if conf.Conf.MaxConcurrency > 0 { net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency} } - if conf.Conf.MaxBufferLimit < 0 { - m, _ := mem.VirtualMemory() - if m != nil { - conf.MaxBufferLimit = max(int(float64(m.Total)*0.05), 4*utils.MB) - conf.MaxBufferLimit -= conf.MaxBufferLimit % utils.MB + + memStat, _ := mem.VirtualMemory() + if memStat != nil { + log.Infof("total memory: %dMB, available: %dMB", memStat.Total>>20, memStat.Available>>20) + } + if conf.Conf.MinFreeMemory < 1 { + if memStat != nil { + t := (memStat.Total >> 20) / 10 + conf.MinFreeMemory = max(16, min(t, 1024)) << 20 } else { - conf.MaxBufferLimit = 16 * utils.MB + conf.MinFreeMemory = 16 * utils.MB } } else { - conf.MaxBufferLimit = conf.Conf.MaxBufferLimit * utils.MB + conf.MinFreeMemory = max(16, uint64(conf.Conf.MinFreeMemory)) << 20 } - log.Infof("max buffer limit: %dMB", conf.MaxBufferLimit/utils.MB) - if conf.Conf.MmapThreshold > 0 { - conf.MmapThreshold = conf.Conf.MmapThreshold * utils.MB + log.Infof("min free memory: %dMB", conf.MinFreeMemory>>20) + + if conf.Conf.MaxBlockLimit < 1 { // 0 would cap every download part and cache block at zero bytes, so auto-detect for it too + if memStat != nil { + t := (memStat.Total >> 20) * 3 / 100 + conf.MaxBlockLimit = max(4, min(uint64(t), 64)) << 20 + } else { + conf.MaxBlockLimit = 16 * utils.MB + } + } else { + conf.MaxBlockLimit = uint64(conf.Conf.MaxBlockLimit) << 20 + } + log.Infof("max block limit: %dMB", conf.MaxBlockLimit>>20) + + if
conf.Conf.CacheThreshold > 0 { + conf.CacheThreshold = uint64(conf.Conf.CacheThreshold) << 20 } else { - conf.MmapThreshold = 0 + conf.CacheThreshold = 0 } - log.Infof("mmap threshold: %dMB", conf.Conf.MmapThreshold) + log.Infof("cache threshold: %dMB", conf.CacheThreshold>>20) if len(conf.Conf.Log.Filter.Filters) == 0 { conf.Conf.Log.Filter.Enable = false diff --git a/internal/conf/config.go b/internal/conf/config.go index f347380d8..002e60ef8 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -120,8 +120,9 @@ type Config struct { DistDir string `json:"dist_dir"` Log LogConfig `json:"log" envPrefix:"LOG_"` DelayedStart int `json:"delayed_start" env:"DELAYED_START"` - MaxBufferLimit int `json:"max_buffer_limitMB" env:"MAX_BUFFER_LIMIT_MB"` - MmapThreshold int `json:"mmap_thresholdMB" env:"MMAP_THRESHOLD_MB"` + MinFreeMemory int `json:"min_free_memory" env:"MIN_FREE_MEMORY"` + MaxBlockLimit int `json:"max_block_limit" env:"MAX_BLOCK_LIMIT"` + CacheThreshold int `json:"cache_threshold" env:"CACHE_THRESHOLD"` MaxConnections int `json:"max_connections" env:"MAX_CONNECTIONS"` MaxConcurrency int `json:"max_concurrency" env:"MAX_CONCURRENCY"` TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"` @@ -178,8 +179,8 @@ func DefaultConfig(dataDir string) *Config { }, }, }, - MaxBufferLimit: -1, - MmapThreshold: 4, + MaxBlockLimit: -1, + CacheThreshold: 4, MaxConnections: 0, MaxConcurrency: 64, TlsInsecureSkipVerify: false, diff --git a/internal/conf/var.go b/internal/conf/var.go index 972f69997..a2f33f9e3 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -25,10 +25,12 @@ var FilenameCharMap = make(map[string]string) var PrivacyReg []*regexp.Regexp var ( - // 单个Buffer最大限制 - MaxBufferLimit = 16 * 1024 * 1024 - // 超过该阈值的Buffer将使用 mmap 分配,可主动释放内存 - MmapThreshold = 4 * 1024 * 1024 + // Maximum size of a single growth step of the memory or disk cache; growth beyond this limit is split into several steps. + MaxBlockLimit uint64 = 16 * 1024 * 1024 + // Data streams larger than this threshold use HybridCache, whose memory can be released explicitly. + CacheThreshold
uint64 = 4 * 1024 * 1024 + // Minimum amount of free memory. + MinFreeMemory uint64 = 16 * 1024 * 1024 ) var ( RawIndexHtml string diff --git a/internal/mem/cache.go b/internal/mem/cache.go new file mode 100644 index 000000000..0ef0ca354 --- /dev/null +++ b/internal/mem/cache.go @@ -0,0 +1,230 @@ +package mem + +import ( + "errors" + "io" + "os" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" +) + +// HybridCache prefers memory and falls back to a file only when memory allocation fails. +// It is not thread-safe. +type HybridCache struct { + mem LinearMemory + memOffset uint64 + file *os.File + fileOffset uint64 + blockSize uint64 +} + +func (hc *HybridCache) NextBlockWithSize(size uint64) buffer.Block { + if hc.file != nil { + if hc.fileOffset > 0 && hc.file.Truncate(int64(hc.fileOffset+size)) != nil { + return nil + } + base := hc.fileOffset + hc.fileOffset += size + fs := buffer.NewBlockAdapter( + io.NewOffsetWriter(hc.file, int64(base)), + io.NewSectionReader(hc.file, int64(base), int64(size)), + ) + return fs + } + all, err := hc.mem.Reallocate(uint64(hc.memOffset + size)) + if err == nil { + start := hc.memOffset + hc.memOffset += size + return buffer.NewByteBlock(all[start : start+size]) + } + if err := hc.initFileCache(); err != nil { + return nil + } + return hc.NextBlockWithSize(size) +} + +func (hc *HybridCache) NextBlock() buffer.Block { + return hc.NextBlockWithSize(hc.blockSize) +} + +// func (hc *HybridCache) GetBlockSize() uint64 { +// return hc.blockSize +// } + +func (hc *HybridCache) RollbackBlockWithSize(size uint64) { + if hc.fileOffset > size { + hc.fileOffset -= size + return + } + size -= hc.fileOffset + hc.fileOffset = 0 + if hc.memOffset > size { + hc.memOffset -= size + return + } + hc.memOffset = 0 +} + +func (hc *HybridCache) RollbackBlock() { + hc.RollbackBlockWithSize(hc.blockSize) +} + +func (hc *HybridCache) initFileCache() error { + f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return err + } + if err := f.Truncate(int64(hc.blockSize)); err != nil { + _, _
= f.Close(), os.Remove(f.Name()) + return err + } + hc.file = f + return nil +} + +func (hc *HybridCache) Close() error { + if hc.blockSize > 0 { + hc.blockSize = 0 + var err error + if hc.mem != nil { + err = hc.mem.Free() + hc.mem = nil + } + if hc.file != nil { + err = errors.Join(err, hc.file.Close(), os.Remove(hc.file.Name())) + hc.file = nil + } + return err + } + return nil +} + +func (hc *HybridCache) Size() int64 { + return int64(hc.memOffset + hc.fileOffset) +} + +func (hc *HybridCache) ReadAt(p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= hc.Size() { + return 0, io.EOF + } + if off < int64(hc.memOffset) { + all, err := hc.mem.Reallocate(min(hc.memOffset, uint64(off)+uint64(len(p)))) + if err != nil { + // Cannot fail: the requested size never exceeds memOffset. + panic(err) + } + n = copy(p, all[off:]) + if n == len(p) { + return n, nil + } + p = p[n:] + } + + off += int64(n) - int64(hc.memOffset) + canRead := int64(hc.fileOffset) - off + if canRead <= 0 { + return n, io.EOF + } + nn, err := hc.file.ReadAt(p[:min(len(p), int(canRead))], off) + return n + nn, err +} + +func (hc *HybridCache) WriteAt(p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= hc.Size() { + return 0, io.ErrShortWrite + } + + if off < int64(hc.memOffset) { + all, err := hc.mem.Reallocate(min(hc.memOffset, uint64(off)+uint64(len(p)))) + if err != nil { + // Cannot fail: the requested size never exceeds memOffset. + panic(err) + } + n = copy(all[off:], p) + if n == len(p) { + return n, nil + } + p = p[n:] + } + + off += int64(n) - int64(hc.memOffset) + canWrite := int64(hc.fileOffset) - off + if canWrite <= 0 { + return n, io.ErrShortWrite + } + nn, err := hc.file.WriteAt(p[:min(len(p), int(canWrite))], off) + return n + nn, err +} + +// NewHybridCache returns a cache that prefers memory and falls back to a file only when memory allocation fails. +// The returned cache is not thread-safe. +func NewHybridCache(blockSize, maxMemorySize uint64) (*HybridCache, error) { + var err error + if conf.CacheThreshold > 0 { + var m LinearMemory + m, err = NewGuardedMemory(blockSize, maxMemorySize) + if err == nil {
+ return &HybridCache{mem: m, blockSize: blockSize}, nil + } + } + hc := &HybridCache{blockSize: blockSize} + if err2 := hc.initFileCache(); err2 != nil { + return nil, errors.Join(err, err2) + } + return hc, nil +} + +var _ buffer.Block = (*HybridCache)(nil) + +type HybridCacheReader struct { + hc *HybridCache + offset int64 +} + +func NewHybridCacheReader(hc *HybridCache) *HybridCacheReader { + return &HybridCacheReader{hc: hc} +} + +func (hcr *HybridCacheReader) Size() int64 { + return hcr.hc.Size() +} + +func (hcr *HybridCacheReader) Read(p []byte) (n int, err error) { + n, err = hcr.hc.ReadAt(p, hcr.offset) + if n > 0 { + hcr.offset += int64(n) + } + return n, err +} + +func (hcr *HybridCacheReader) ReadAt(p []byte, off int64) (n int, err error) { + return hcr.hc.ReadAt(p, off) +} + +func (hcr *HybridCacheReader) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + case io.SeekCurrent: + if offset == 0 { + return hcr.offset, nil + } + offset = hcr.offset + offset + case io.SeekEnd: + offset = hcr.Size() + offset + default: + return 0, errors.New("Seek: invalid whence") + } + + if offset < 0 || offset > hcr.Size() { + return 0, errors.New("Seek: invalid offset") + } + hcr.offset = offset + return offset, nil +} diff --git a/internal/mem/mem_other.go b/internal/mem/mem_other.go new file mode 100644 index 000000000..59a892ebc --- /dev/null +++ b/internal/mem/mem_other.go @@ -0,0 +1,25 @@ +//go:build !unix && !windows + +package mem // import "github.com/ncruces/go-sqlite3/internal/alloc" + +func NewMemory(cap, max uint64) (LinearMemory, error) { + return &sliceMemory{buf: make([]byte, 0, cap)}, nil +} + +type sliceMemory struct { + buf []byte +} + +func (b *sliceMemory) Free() error { + b.buf = nil + return nil +} + +func (b *sliceMemory) Reallocate(size uint64) ([]byte, error) { + if cap := uint64(cap(b.buf)); size > cap { + b.buf = append(b.buf[:cap], make([]byte, size-cap)...) 
+ } else { + b.buf = b.buf[:size] + } + return b.buf, nil +} diff --git a/internal/mem/mem_unix.go b/internal/mem/mem_unix.go new file mode 100644 index 000000000..4307daa91 --- /dev/null +++ b/internal/mem/mem_unix.go @@ -0,0 +1,92 @@ +//go:build unix + +package mem // import "github.com/ncruces/go-sqlite3/internal/alloc" + +import ( + "math" + + "golang.org/x/sys/unix" +) + +func NewMemory(cap, max uint64) (LinearMemory, error) { + // Round up to the page size. + rnd := uint64(unix.Getpagesize() - 1) + res := (max + rnd) &^ rnd + + if res > math.MaxInt { + // This ensures int(res) overflows to a negative value, + // and unix.Mmap returns EINVAL. + res = math.MaxUint64 + } + + com := res + prot := unix.PROT_READ | unix.PROT_WRITE + if cap < max { // Commit memory only if cap=max. + com = 0 + prot = unix.PROT_NONE + } + + // Reserve res bytes of address space, to ensure we won't need to move it. + // A protected, private, anonymous mapping should not commit memory. + b, err := unix.Mmap(-1, 0, int(res), prot, unix.MAP_PRIVATE|unix.MAP_ANON) + if err != nil { + return nil, err + } + return &mmappedMemory{buf: b[:com]}, nil +} + +// The slice covers the entire mmapped memory: +// - len(buf) is the already committed memory, +// - cap(buf) is the reserved address space. +type mmappedMemory struct { + buf []byte + growCheck GrowCheck +} + +func (m *mmappedMemory) SetGrowCheck(c GrowCheck) { + m.growCheck = c +} + +func (m *mmappedMemory) Reallocate(size uint64) ([]byte, error) { + com := uint64(len(m.buf)) + res := uint64(cap(m.buf)) + if com < size { + if size <= res { + // Grow geometrically, round up to the page size. + rnd := uint64(unix.Getpagesize() - 1) + new := com + com>>3 + new = min(max(size, new), res) + new = (new + rnd) &^ rnd + + if m.growCheck != nil { + if err := m.growCheck(new - com); err != nil { + return nil, err + } + } + + // Commit additional memory up to new bytes. 
+ err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE) + if err != nil { + return nil, err + } + + m.buf = m.buf[:new] // Update committed memory. + } else { + return nil, ErrNotEnoughMemory + } + } + // Limit returned capacity because bytes beyond + // len(m.buf) have not yet been committed. + return m.buf[:size:len(m.buf)], nil +} + +func (m *mmappedMemory) Free() error { + if m.buf != nil { + err := unix.Munmap(m.buf[:cap(m.buf)]) + if err != nil { + return err + } + m.buf = nil + } + return nil +} diff --git a/internal/mem/mem_windows.go b/internal/mem/mem_windows.go new file mode 100644 index 000000000..9db1a3db9 --- /dev/null +++ b/internal/mem/mem_windows.go @@ -0,0 +1,94 @@ +package mem // import "github.com/ncruces/go-sqlite3/internal/alloc" + +import ( + "math" + "unsafe" + + "golang.org/x/sys/windows" +) + +func NewMemory(cap, max uint64) (LinearMemory, error) { + // Round up to the page size. + rnd := uint64(windows.Getpagesize() - 1) + res := (max + rnd) &^ rnd + + if res > math.MaxInt { + // This ensures uintptr(res) overflows to a large value, + // and windows.VirtualAlloc returns an error. + res = math.MaxUint64 + } + + com := res + kind := windows.MEM_COMMIT + if cap < max { // Commit memory only if cap=max. + com = 0 + kind = windows.MEM_RESERVE + } + + // Reserve res bytes of address space, to ensure we won't need to move it. + r, err := windows.VirtualAlloc(0, uintptr(res), uint32(kind), windows.PAGE_READWRITE) + if err != nil { + return nil, err + } + + buf := unsafe.Slice((*byte)(unsafe.Pointer(r)), int(res)) + return &virtualMemory{addr: r, buf: buf[:com]}, nil +} + +// The slice covers the entire mmapped memory: +// - len(buf) is the already committed memory, +// - cap(buf) is the reserved address space. 
+type virtualMemory struct { + buf []byte + addr uintptr + growCheck GrowCheck +} + +func (m *virtualMemory) SetGrowCheck(c GrowCheck) { + m.growCheck = c +} + +func (m *virtualMemory) Reallocate(size uint64) ([]byte, error) { + com := uint64(len(m.buf)) + res := uint64(cap(m.buf)) + if com < size { + if size <= res { + // Grow geometrically, round up to the page size. + rnd := uint64(windows.Getpagesize() - 1) + new := com + com>>3 + new = min(max(size, new), res) + new = (new + rnd) &^ rnd + + if m.growCheck != nil { + if err := m.growCheck(new - com); err != nil { + return nil, err + } + } + + // Commit additional memory up to new bytes. + _, err := windows.VirtualAlloc(m.addr, uintptr(new), windows.MEM_COMMIT, windows.PAGE_READWRITE) + if err != nil { + return nil, err + } + + m.buf = m.buf[:new] // Update committed memory. + } else { + return nil, ErrNotEnoughMemory + } + } + // Limit returned capacity because bytes beyond + // len(m.buf) have not yet been committed. + return m.buf[:size:len(m.buf)], nil +} + +func (m *virtualMemory) Free() error { + if m.addr != 0 { + err := windows.VirtualFree(m.addr, 0, windows.MEM_RELEASE) + if err != nil { + return err + } + m.addr = 0 + m.buf = nil + } + return nil +} diff --git a/internal/mem/type.go b/internal/mem/type.go new file mode 100644 index 000000000..3ba8e35ca --- /dev/null +++ b/internal/mem/type.go @@ -0,0 +1,9 @@ +package mem + +type LinearMemory interface { + // Not thread-safe. + Reallocate(size uint64) (all []byte, err error) + Free() error +} + +type GrowCheck func(growSize uint64) error diff --git a/internal/mem/utils.go b/internal/mem/utils.go new file mode 100644 index 000000000..f54cb3d10 --- /dev/null +++ b/internal/mem/utils.go @@ -0,0 +1,75 @@ +package mem + +import ( + "errors" + "fmt" + "runtime" + "sync/atomic" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/pkg/singleflight" + "github.com/shirou/gopsutil/v4/mem" +) + +var ErrNotEnoughMemory =
errors.New("not enough memory") + +func MemoryGrowCheck(growSize uint64) error { + m, err, _ := singleflight.AnyGroup.Do("SafeMemory.GrowLimit", func() (any, error) { + m, err := mem.VirtualMemory() + if err != nil { + return nil, err + } + if m.Available < conf.MinFreeMemory { + return nil, ErrNotEnoughMemory + } + return m, nil + }) + if err != nil { + return err + } + memStat := m.(*mem.VirtualMemoryStat) + for { + available := atomic.LoadUint64(&memStat.Available) + if available < growSize || available-growSize < conf.MinFreeMemory { + return ErrNotEnoughMemory + } + if atomic.CompareAndSwapUint64(&memStat.Available, available, available-growSize) { + return nil + } + } +} + +func NewGuardedMemory(cap, max uint64) (m LinearMemory, err error) { + if err := MemoryGrowCheck(cap); err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%w: %v", ErrNotEnoughMemory, r) + } + }() + m, err = NewMemory(cap, max) + if err != nil { + return nil, err + } + runtime.SetFinalizer(m, func(m LinearMemory) { + m.Free() + }) + if s, ok := m.(interface{ SetGrowCheck(GrowCheck) }); ok { + s.SetGrowCheck(MemoryGrowCheck) + } + return &guardedMemory{m}, nil +} + +type guardedMemory struct { + LinearMemory +} + +func (s *guardedMemory) Reallocate(size uint64) (all []byte, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%w: %v", ErrNotEnoughMemory, r) + } + }() + return s.LinearMemory.Reallocate(size) +} diff --git a/internal/model/file.go b/internal/model/file.go index 4ca7201e1..d6697cd0e 100644 --- a/internal/model/file.go +++ b/internal/model/file.go @@ -7,24 +7,26 @@ import ( // File is basic file level accessing interface type File interface { - io.Reader io.ReaderAt - io.Seeker + io.ReadSeeker +} +type FileWriter interface { + io.WriterAt + io.WriteSeeker } type FileCloser struct { File io.Closer } -func (f *FileCloser) Close() error { - var errs []error +func (f *FileCloser) Close() (err error) { 
if clr, ok := f.File.(io.Closer); ok { - errs = append(errs, clr.Close()) + err = clr.Close() } if f.Closer != nil { - errs = append(errs, f.Closer.Close()) + return errors.Join(err, f.Closer.Close()) } - return errors.Join(errs...) + return } // FileRangeReader 是对 RangeReaderIF 的轻量包装,表明由 RangeReaderIF.RangeRead diff --git a/internal/net/request.go b/internal/net/request.go index e1f045120..bc10d74b0 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -13,9 +13,10 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/mem" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/rclone/rclone/lib/mmap" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/aws/aws-sdk-go/aws/awsutil" @@ -86,8 +87,8 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo if impl.cfg.PartSize == 0 { impl.cfg.PartSize = DefaultDownloadPartSize } - if conf.MaxBufferLimit > 0 && impl.cfg.PartSize > conf.MaxBufferLimit { - impl.cfg.PartSize = conf.MaxBufferLimit + if impl.cfg.PartSize > int(conf.MaxBlockLimit) { + impl.cfg.PartSize = int(conf.MaxBlockLimit) } if impl.cfg.HttpClient == nil { impl.cfg.HttpClient = DefaultHttpRequestFunc @@ -109,7 +110,7 @@ type downloader struct { m sync.Mutex nextChunk int //next chunk id - bufs []*Buf + bufs []*buffer.PipeBuffer written int64 //total bytes of file downloaded from remote err error @@ -119,6 +120,8 @@ type downloader struct { maxPos int64 m2 sync.Mutex readingID int // 正在被读取的id + + hc *mem.HybridCache } type ConcurrencyLimit struct { @@ -202,15 +205,25 @@ func (d *downloader) download() (io.ReadCloser, error) { d.pos = d.params.Range.Start d.maxPos = d.params.Range.Start + d.params.Range.Length d.concurrency = d.cfg.Concurrency - _ = d.sendChunkTask(true) - var rc 
io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf) + if d.params.Range.Length > int64(conf.CacheThreshold) { + d.hc, d.err = mem.NewHybridCache(uint64(d.cfg.PartSize), uint64(d.params.Range.Length)) + } + if d.err == nil { + d.err = d.sendChunkTask(true) + } + if d.err != nil { + d.concurrencyFinish() + return nil, d.interrupt() + } + + var rc io.ReadCloser = newMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf) // Return error return rc, d.err } -func (d *downloader) sendChunkTask(newConcurrency bool) error { +func (d *downloader) sendChunkTask(newConcurrency bool) (err error) { d.m.Lock() defer d.m.Unlock() isNewBuf := d.concurrency > 0 @@ -227,12 +240,26 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error { go d.downloadPart() } - var buf *Buf + var br *buffer.PipeBuffer if isNewBuf { - buf = NewBuf(d.ctx, d.cfg.PartSize) - d.bufs = append(d.bufs, buf) + var b buffer.Block + if d.hc != nil { + b = d.hc.NextBlock() + if b == nil { + return fmt.Errorf("failed to create new buffer section") + } + } else { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic in creating new buffer section: %v", r) + } + }() + b = buffer.NewByteBlock(make([]byte, d.cfg.PartSize)) + } + br = buffer.NewPipeBuffer(d.ctx, b) + d.bufs = append(d.bufs, br) } else { - buf = d.getBuf(d.nextChunk) + br = d.getBuf(d.nextChunk) } if d.pos < d.maxPos { @@ -256,7 +283,7 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error { finalSize += firstSize - minSize } } - err := buf.Reset(int(finalSize)) + err := br.Reset(int(finalSize)) if err != nil { return err } @@ -264,7 +291,7 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error { start: d.pos, size: finalSize, id: d.nextChunk, - buf: buf, + buf: br, newConcurrency: newConcurrency, } @@ -286,24 +313,30 @@ func (d *downloader) interrupt() error { err := fmt.Errorf("interrupted") d.err = err } - close(d.chunkChannel) if d.bufs != nil { - d.cancel(err) for _, buf := range 
d.bufs { buf.Close() } d.bufs = nil + } + if d.cancel != nil { + d.cancel(err) + close(d.chunkChannel) + if d.hc != nil { + d.hc.Close() + } if d.concurrency > 0 { d.concurrency = -d.concurrency } log.Debugf("maxConcurrency:%d", d.cfg.Concurrency+d.concurrency) + d.cancel = nil } return err } -func (d *downloader) getBuf(id int) (b *Buf) { +func (d *downloader) getBuf(id int) *buffer.PipeBuffer { return d.bufs[id%len(d.bufs)] } -func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) { +func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *buffer.PipeBuffer) { id++ if id >= d.maxPart { return true, nil @@ -548,7 +581,7 @@ func (d *downloader) setErr(e error) { type chunk struct { start int64 size int64 - buf *Buf + buf *buffer.PipeBuffer id int newConcurrency bool @@ -598,185 +631,36 @@ func (e *errNeedRetry) Unwrap() error { return e.err } -type MultiReadCloser struct { - cfg *cfg - closer closerFunc - finish finishBufFUnc -} - -type cfg struct { +type multiReadCloser struct { rPos int //current reader position, start from 0 - curBuf *Buf + curBuf *buffer.PipeBuffer + finish finishBufFUnc + utils.CloseFunc } -type closerFunc func() error -type finishBufFUnc func(id int) (isLast bool, buf *Buf) +type finishBufFUnc func(id int) (isLast bool, buf *buffer.PipeBuffer) -// NewMultiReadCloser to save memory, we re-use limited Buf, and feed data to Read() -func NewMultiReadCloser(buf *Buf, c closerFunc, fb finishBufFUnc) *MultiReadCloser { - return &MultiReadCloser{closer: c, finish: fb, cfg: &cfg{curBuf: buf}} +// newMultiReadCloser to save memory, we re-use limited Buf, and feed data to Read() +func newMultiReadCloser(buf *buffer.PipeBuffer, c utils.CloseFunc, fb finishBufFUnc) *multiReadCloser { + return &multiReadCloser{CloseFunc: c, finish: fb, curBuf: buf} } -func (mr MultiReadCloser) Read(p []byte) (n int, err error) { - if mr.cfg.curBuf == nil { +func (mr *multiReadCloser) Read(p []byte) (n int, err error) { + if mr.curBuf == nil { return 0, 
io.EOF } - n, err = mr.cfg.curBuf.Read(p) - //log.Debugf("read_%d read current buffer, n=%d ,err=%+v", mr.cfg.rPos, n, err) + n, err = mr.curBuf.Read(p) + // log.Debugf("read_%d read current buffer, n=%d ,err=%+v", mr.rPos, n, err) if err == io.EOF { - log.Debugf("read_%d finished current buffer", mr.cfg.rPos) + log.Debugf("read_%d finished current buffer", mr.rPos) - isLast, next := mr.finish(mr.cfg.rPos) + isLast, next := mr.finish(mr.rPos) if isLast { return n, io.EOF } - mr.cfg.curBuf = next - mr.cfg.rPos++ + mr.curBuf = next + mr.rPos++ return n, nil } - if err == context.Canceled { - if e := context.Cause(mr.cfg.curBuf.ctx); e != nil { - err = e - } - } return n, err } -func (mr MultiReadCloser) Close() error { - return mr.closer() -} - -type Buf struct { - size int //expected size - ctx context.Context - offR int - offW int - rw sync.Mutex - buf []byte - mmap bool - - readSignal chan struct{} - readPending bool -} - -// NewBuf is a buffer that can have 1 read & 1 write at the same time. 
-// when read is faster write, immediately feed data to read after written -func NewBuf(ctx context.Context, maxSize int) *Buf { - br := &Buf{ - ctx: ctx, - size: maxSize, - readSignal: make(chan struct{}, 1), - } - if conf.MmapThreshold > 0 && maxSize >= conf.MmapThreshold { - m, err := mmap.Alloc(maxSize) - if err == nil { - br.buf = m - br.mmap = true - return br - } - } - br.buf = make([]byte, maxSize) - return br -} - -func (br *Buf) Reset(size int) error { - br.rw.Lock() - defer br.rw.Unlock() - if br.buf == nil { - return io.ErrClosedPipe - } - if size > cap(br.buf) { - return fmt.Errorf("reset size %d exceeds max size %d", size, cap(br.buf)) - } - br.size = size - br.offR = 0 - br.offW = 0 - return nil -} - -func (br *Buf) Read(p []byte) (int, error) { - if err := br.ctx.Err(); err != nil { - return 0, err - } - if len(p) == 0 { - return 0, nil - } - if br.offR >= br.size { - return 0, io.EOF - } - for { - br.rw.Lock() - if br.buf == nil { - br.rw.Unlock() - return 0, io.ErrClosedPipe - } - - if br.offW < br.offR { - br.rw.Unlock() - return 0, io.ErrUnexpectedEOF - } - if br.offW == br.offR { - br.readPending = true - br.rw.Unlock() - select { - case <-br.ctx.Done(): - return 0, br.ctx.Err() - case _, ok := <-br.readSignal: - if !ok { - return 0, io.ErrClosedPipe - } - continue - } - } - - n := copy(p, br.buf[br.offR:br.offW]) - br.offR += n - br.rw.Unlock() - if n < len(p) && br.offR >= br.size { - return n, io.EOF - } - return n, nil - } -} - -func (br *Buf) Write(p []byte) (int, error) { - if err := br.ctx.Err(); err != nil { - return 0, err - } - if len(p) == 0 { - return 0, nil - } - br.rw.Lock() - defer br.rw.Unlock() - if br.buf == nil { - return 0, io.ErrClosedPipe - } - if br.offW >= br.size { - return 0, io.ErrShortWrite - } - n := copy(br.buf[br.offW:], p[:min(br.size-br.offW, len(p))]) - br.offW += n - if br.readPending { - br.readPending = false - select { - case br.readSignal <- struct{}{}: - default: - } - } - if n < len(p) { - return n, 
io.ErrShortWrite - } - return n, nil -} - -func (br *Buf) Close() error { - br.rw.Lock() - defer br.rw.Unlock() - var err error - if br.mmap { - err = mmap.Free(br.buf) - br.mmap = false - } - br.buf = nil - close(br.readSignal) - return err -} diff --git a/internal/net/request_test.go b/internal/net/request_test.go index da16a3165..3ec425e95 100644 --- a/internal/net/request_test.go +++ b/internal/net/request_test.go @@ -16,8 +16,6 @@ import ( "github.com/sirupsen/logrus" ) -var buf22MB = make([]byte, 1024*1024*22) - func containsString(slice []string, val string) bool { for _, item := range slice { if item == val { @@ -27,18 +25,6 @@ func containsString(slice []string, val string) bool { return false } -func dummyHttpRequest(data []byte, p http_range.Range) io.ReadCloser { - - end := p.Start + p.Length - 1 - - if end >= int64(len(data)) { - end = int64(len(data)) - } - - bodyBytes := data[p.Start:end] - return io.NopCloser(bytes.NewReader(bodyBytes)) -} - func TestDownloadOrder(t *testing.T) { buff := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} downloader, invocations, ranges := newDownloadRangeClient(buff) diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 4c8238100..97d93e863 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -1,6 +1,7 @@ package stream import ( + "bytes" "context" "errors" "fmt" @@ -10,11 +11,11 @@ import ( "sync" "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/mem" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/rclone/rclone/lib/mmap" "go4.org/readerutil" ) @@ -28,8 +29,14 @@ type FileStream struct { Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it utils.Closers size int64 - peekBuff *buffer.Reader oriReader 
io.Reader // the original reader, used for caching + hc *mem.HybridCache + peek peek +} + +type peek interface { + model.File + Size() int64 } func (f *FileStream) GetSize() int64 { @@ -51,15 +58,6 @@ func (f *FileStream) IsForceStreamUpload() bool { return f.ForceStreamUpload } -func (f *FileStream) Close() error { - if f.peekBuff != nil { - f.peekBuff.Reset() - f.oriReader = nil - f.peekBuff = nil - } - return f.Closers.Close() -} - func (f *FileStream) GetExist() model.Obj { return f.Exist } @@ -101,79 +99,57 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ } reader := f.Reader - if f.peekBuff != nil { - f.peekBuff.Seek(0, io.SeekStart) + if f.peek != nil { + f.peek.Seek(0, io.SeekStart) if writer != nil { - _, err := utils.CopyWithBuffer(writer, f.peekBuff) + _, err := utils.CopyWithBuffer(writer, f.peek) if err != nil { return nil, err } - f.peekBuff.Seek(0, io.SeekStart) + f.peek.Seek(0, io.SeekStart) } reader = f.oriReader } if writer != nil { reader = io.TeeReader(reader, writer) } + + // 如果文件大小未知,直接缓存到磁盘 if f.GetSize() < 0 { - if f.peekBuff == nil { - f.peekBuff = &buffer.Reader{} - } // 检查是否有数据 buf := []byte{0} n, err := io.ReadFull(reader, buf) - if n > 0 { - f.peekBuff.Append(buf[:n]) - } - if err == io.ErrUnexpectedEOF { - f.size = f.peekBuff.Size() - f.Reader = f.peekBuff - return f.peekBuff, nil + br := bytes.NewReader(buf[:n]) + if err == io.ErrUnexpectedEOF || err == io.EOF { + f.size = br.Size() + f.Reader = br + return br, nil } else if err != nil { return nil, err } - if conf.MaxBufferLimit-n > conf.MmapThreshold && conf.MmapThreshold > 0 { - m, err := mmap.Alloc(conf.MaxBufferLimit - n) - if err == nil { - f.Add(utils.CloseFunc(func() error { - return mmap.Free(m) - })) - n, err = io.ReadFull(reader, m) - if n > 0 { - f.peekBuff.Append(m[:n]) - } - if err == io.ErrUnexpectedEOF { - f.size = f.peekBuff.Size() - f.Reader = f.peekBuff - return f.peekBuff, nil - } else if err != nil { - return nil, err - } - } - } 
- tmpF, err := utils.CreateTempFile(reader, 0) + tmpF, err := utils.CreateTempFile(io.MultiReader(br, reader), 0) if err != nil { return nil, err } f.Add(utils.CloseFunc(func() error { return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name())) })) - peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF) + stat, err := tmpF.Stat() if err != nil { return nil, err } - f.size = peekF.Size() - f.Reader = peekF - return peekF, nil + f.size = stat.Size() + f.Reader = tmpF + return tmpF, nil } if up != nil { cacheProgress := model.UpdateProgressWithRange(*up, 0, 50) *up = model.UpdateProgressWithRange(*up, 50, 100) size := f.GetSize() - if f.peekBuff != nil { - peekSize := f.peekBuff.Size() - cacheProgress(float64(peekSize) / float64(size) * 100) + if f.peek != nil { + peekSize := f.peek.Size() + // cacheProgress(float64(peekSize) / float64(size) * 100) size -= peekSize } reader = &ReaderUpdatingProgress{ @@ -185,7 +161,7 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ } } - if f.peekBuff != nil { + if f.oriReader != nil { f.oriReader = reader } else { f.Reader = reader @@ -225,63 +201,55 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { // 确保指定大小的数据被缓存 func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { - if maxCacheSize > int64(conf.MaxBufferLimit) { - size := f.GetSize() - reader := f.Reader - if f.peekBuff != nil { - size -= f.peekBuff.Size() - reader = f.oriReader - } - tmpF, err := utils.CreateTempFile(reader, size) - if err != nil { - return nil, err - } - f.Add(utils.CloseFunc(func() error { - return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name())) - })) - if f.peekBuff != nil { - peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF) + if f.peek == nil { + f.oriReader = f.Reader + if f.GetSize() > int64(conf.CacheThreshold) { + blockSize := min(f.GetSize(), int64(conf.MaxBlockLimit)) + hc, err := mem.NewHybridCache(uint64(blockSize), uint64(f.GetSize())) if err != nil { return nil, 
err } - f.Reader = peekF - return peekF, nil + f.hc = hc + f.peek = mem.NewHybridCacheReader(hc) + f.Reader = io.MultiReader(f.peek, f.oriReader) + f.Add(hc) + } else { + br := &buffer.Reader{} + f.peek = br + f.Reader = io.MultiReader(br, f.oriReader) + f.Add(br) + } + } + cacheSize := maxCacheSize - f.peek.Size() + if f.hc != nil { + cacheSize2 := cacheSize + for cacheSize > 0 { + blockSize := min(cacheSize, int64(conf.MaxBlockLimit)) + b := f.hc.NextBlockWithSize(uint64(blockSize)) + if b == nil { + return nil, fmt.Errorf("failed to get cache section") + } + n, err := utils.CopyWithBufferN(buffer.WriteAtSeekerOf(b), f.oriReader, blockSize) + if n != blockSize { + f.hc.RollbackBlockWithSize(uint64(blockSize - n)) + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", cacheSize2, cacheSize2-cacheSize+n, err) + } + cacheSize -= n } - f.Reader = tmpF - return tmpF, nil - } - - if f.peekBuff == nil { - f.peekBuff = &buffer.Reader{} - f.oriReader = f.Reader - f.Reader = io.MultiReader(f.peekBuff, f.oriReader) - } - bufSize := maxCacheSize - f.peekBuff.Size() - if bufSize <= 0 { - return f.peekBuff, nil - } - var buf []byte - if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) { - m, err := mmap.Alloc(int(bufSize)) - if err == nil { - f.Add(utils.CloseFunc(func() error { - return mmap.Free(m) - })) - buf = m + } else { + if cacheSize > 0 { + buf := make([]byte, cacheSize) + n, err := io.ReadFull(f.oriReader, buf) + if n != len(buf) { + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", len(buf), n, err) + } + f.peek.(*buffer.Reader).Append(buf) } } - if buf == nil { - buf = make([]byte, bufSize) - } - n, err := io.ReadFull(f.oriReader, buf) - if bufSize != int64(n) { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) - } - f.peekBuff.Append(buf) - if f.peekBuff.Size() >= f.GetSize() { - f.Reader = f.peekBuff + if f.peek.Size() >= f.GetSize() { + f.Reader = 
f.peek } - return f.peekBuff, nil + return f.peek, nil } var _ model.FileStreamer = (*SeekableStream)(nil) diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 9a81e7d41..1b1d8bd07 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -1,13 +1,16 @@ -package stream +package stream_test import ( "bytes" "errors" "fmt" "io" + "math" "testing" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" ) @@ -17,17 +20,19 @@ func TestFileStream_RangeRead(t *testing.T) { httpRange http_range.Range } buf := []byte("github.com/OpenListTeam/OpenList") - f := &FileStream{ + f := &stream.FileStream{ Obj: &model.Object{ Size: int64(len(buf)), }, Reader: io.NopCloser(bytes.NewReader(buf)), } + conf.CacheThreshold = 10 + conf.MaxBlockLimit = 15 tests := []struct { name string - f *FileStream + f *stream.FileStream args args - want func(f *FileStream, got io.Reader, err error) error + want func(f *stream.FileStream, got io.Reader, err error) error }{ { name: "range 11-12", @@ -35,7 +40,7 @@ func TestFileStream_RangeRead(t *testing.T) { args: args{ httpRange: http_range.Range{Start: 11, Length: 12}, }, - want: func(f *FileStream, got io.Reader, err error) error { + want: func(f *stream.FileStream, got io.Reader, err error) error { if f.GetFile() != nil { return errors.New("cached") } @@ -52,7 +57,7 @@ func TestFileStream_RangeRead(t *testing.T) { args: args{ httpRange: http_range.Range{Start: 11, Length: 21}, }, - want: func(f *FileStream, got io.Reader, err error) error { + want: func(f *stream.FileStream, got io.Reader, err error) error { if f.GetFile() == nil { return errors.New("not cached") } @@ -86,12 +91,14 @@ func TestFileStream_RangeRead(t *testing.T) { func TestFileStream_With_PreHash(t *testing.T) { buf := 
[]byte("github.com/OpenListTeam/OpenList") - f := &FileStream{ + f := &stream.FileStream{ Obj: &model.Object{ Size: int64(len(buf)), }, Reader: io.NopCloser(bytes.NewReader(buf)), } + conf.CacheThreshold = 10 + conf.MaxBlockLimit = 15 const hashSize int64 = 20 reader, _ := f.RangeRead(http_range.Range{Start: 0, Length: hashSize}) @@ -99,7 +106,7 @@ func TestFileStream_With_PreHash(t *testing.T) { if preHash == "" { t.Error("preHash is empty") } - tmpF, fullHash, _ := CacheFullAndHash(f, nil, utils.SHA1) + tmpF, fullHash, _ := stream.CacheFullAndHash(f, nil, utils.SHA1) fmt.Println(fullHash) fileFullHash, _ := utils.HashFile(utils.SHA1, tmpF) fmt.Println(fileFullHash) @@ -107,3 +114,47 @@ func TestFileStream_With_PreHash(t *testing.T) { t.Errorf("fullHash and fileFullHash should match: fullHash=%s fileFullHash=%s", fullHash, fileFullHash) } } + +func TestStreamSectionReader(t *testing.T) { + buf := make([]byte, 8<<10) + for i := range len(buf) { + buf[i] = byte(i % 256) + } + f := &stream.FileStream{ + Obj: &model.Object{ + Size: int64(len(buf)), + }, + Reader: io.NopCloser(bytes.NewReader(buf)), + } + conf.CacheThreshold = 1 + conf.MaxBlockLimit = 2 << 10 + partSize := 3 << 10 + ss, err := stream.NewStreamSectionReader(f, partSize, nil) + if err != nil { + t.Errorf("NewStreamSectionReader() error = %v", err) + } + conf.Conf = &conf.Config{} + for i := 0; i < len(buf); i += partSize { + length := partSize + if i+length > len(buf) { + length = len(buf) - i + } + rs, err := ss.GetSectionReader(int64(i), int64(length)) + if err != nil { + t.Errorf("StreamSectionReader.GetSectionReader() error = %v", err) + } + b1, err := io.ReadAll(rs) + if err != nil { + t.Errorf("StreamSectionReader.Read() error = %v", err) + } + rs.Seek(1, io.SeekStart) + b2, _ := io.ReadAll(rs) + if !bytes.Equal(b1[1:], b2) { + t.Errorf("StreamSectionReader.Read() = %s, want %s", b1[1:], b2) + } + if !bytes.Equal(buf[i:i+length], b1) { + t.Errorf("StreamSectionReader.Read() = %s, want %s", b1, 
buf[i:i+length]) + } + conf.MinFreeMemory = math.MaxUint64 + } +} diff --git a/internal/stream/util.go b/internal/stream/util.go index 6aa3dda5d..856aff91d 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -8,16 +8,16 @@ import ( "fmt" "io" "net/http" - "os" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/mem" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/net" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/pool" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/rclone/rclone/lib/mmap" log "github.com/sirupsen/logrus" ) @@ -174,7 +174,7 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT return tmpF, hex.EncodeToString(h.Sum(nil)), nil } -type StreamSectionReaderIF interface { +type StreamSectionReader interface { // 线程不安全 GetSectionReader(off, length int64) (io.ReadSeeker, error) FreeSectionReader(sr io.ReadSeeker) @@ -182,71 +182,30 @@ type StreamSectionReaderIF interface { DiscardSection(off int64, length int64) error } -func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int, up *model.UpdateProgress) (StreamSectionReaderIF, error) { +func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int, up *model.UpdateProgress) (StreamSectionReader, error) { if file.GetFile() != nil { return &cachedSectionReader{file.GetFile()}, nil } maxBufferSize = min(maxBufferSize, int(file.GetSize())) - if maxBufferSize > conf.MaxBufferLimit { - f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return nil, err - } - - if f.Truncate(file.GetSize()) != nil { - // fallback to full cache - _, _ = f.Close(), os.Remove(f.Name()) - cache, err := file.CacheFullAndWriter(up, nil) - if err != nil { - return nil, err - } - return 
&cachedSectionReader{cache}, nil - } - - ss := &fileSectionReader{file: file, temp: f} - ss.bufPool = &pool.Pool[*offsetWriterWithBase]{ - New: func() *offsetWriterWithBase { - base := ss.tempOffset - ss.tempOffset += int64(maxBufferSize) - return &offsetWriterWithBase{io.NewOffsetWriter(ss.temp, base), base} - }, - } - file.Add(utils.CloseFunc(func() error { - ss.bufPool.Reset() - return errors.Join(ss.temp.Close(), os.Remove(ss.temp.Name())) - })) - return ss, nil - } - - ss := &directSectionReader{file: file} - if conf.MmapThreshold > 0 && maxBufferSize >= conf.MmapThreshold { - ss.bufPool = &pool.Pool[[]byte]{ - New: func() []byte { - buf, err := mmap.Alloc(maxBufferSize) - if err == nil { - file.Add(utils.CloseFunc(func() error { - return mmap.Free(buf) - })) - } else { - buf = make([]byte, maxBufferSize) - } - return buf - }, - } - } else { - ss.bufPool = &pool.Pool[[]byte]{ + if file.GetSize() <= int64(conf.CacheThreshold) { + bufPool := &pool.Pool[[]byte]{ New: func() []byte { return make([]byte, maxBufferSize) }, } + file.Add(bufPool) + ss := &byteSectionReader{file: file, bufPool: bufPool} + return ss, nil } - file.Add(utils.CloseFunc(func() error { - ss.bufPool.Reset() - return nil - })) - return ss, nil + blockSize := min(uint64(maxBufferSize), conf.MaxBlockLimit) + hc, err := mem.NewHybridCache(blockSize, uint64(file.GetSize())) + if err != nil { + return nil, err + } + file.Add(hc) + return &hybridSectionReader{file: file, hc: hc}, nil } type cachedSectionReader struct { @@ -261,21 +220,14 @@ func (s *cachedSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker } func (*cachedSectionReader) FreeSectionReader(sr io.ReadSeeker) {} -type fileSectionReader struct { +type byteSectionReader struct { file model.FileStreamer fileOffset int64 - temp *os.File - tempOffset int64 - bufPool *pool.Pool[*offsetWriterWithBase] -} - -type offsetWriterWithBase struct { - *io.OffsetWriter - base int64 + bufPool *pool.Pool[[]byte] } // 线程不安全 -func (ss 
*fileSectionReader) DiscardSection(off int64, length int64) error { +func (ss *byteSectionReader) DiscardSection(off int64, length int64) error { if off != ss.fileOffset { return fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } @@ -287,42 +239,42 @@ func (ss *fileSectionReader) DiscardSection(off int64, length int64) error { return nil } -type fileBufferSectionReader struct { +type bytesRefReadSeeker struct { io.ReadSeeker - fileBuf *offsetWriterWithBase + buf []byte } // 线程不安全 -func (ss *fileSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { +func (ss *byteSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { if off != ss.fileOffset { return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } - fileBuf := ss.bufPool.Get() - _, _ = fileBuf.Seek(0, io.SeekStart) - n, err := utils.CopyWithBufferN(fileBuf, ss.file, length) - ss.fileOffset += n - if err != nil { + tempBuf := ss.bufPool.Get() + buf := tempBuf[:length] + n, err := io.ReadFull(ss.file, buf) + ss.fileOffset += int64(n) + if int64(n) != length { return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) } - return &fileBufferSectionReader{io.NewSectionReader(ss.temp, fileBuf.base, length), fileBuf}, nil + return &bytesRefReadSeeker{bytes.NewReader(buf), buf}, nil } - -func (ss *fileSectionReader) FreeSectionReader(rs io.ReadSeeker) { - if sr, ok := rs.(*fileBufferSectionReader); ok { - ss.bufPool.Put(sr.fileBuf) - sr.fileBuf = nil +func (ss *byteSectionReader) FreeSectionReader(rs io.ReadSeeker) { + if sr, ok := rs.(*bytesRefReadSeeker); ok { + ss.bufPool.Put(sr.buf[0:cap(sr.buf)]) + sr.buf = nil sr.ReadSeeker = nil } } -type directSectionReader struct { +type hybridSectionReader struct { file model.FileStreamer fileOffset int64 - bufPool *pool.Pool[[]byte] + hc *mem.HybridCache + cache []buffer.Block } // 线程不安全 -func (ss 
*directSectionReader) DiscardSection(off int64, length int64) error { +func (ss *hybridSectionReader) DiscardSection(off int64, length int64) error { if off != ss.fileOffset { return fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } @@ -334,29 +286,67 @@ func (ss *directSectionReader) DiscardSection(off int64, length int64) error { return nil } -type bufferSectionReader struct { +type blockRefReadSeeker struct { io.ReadSeeker - buf []byte + b buffer.Block } // 线程不安全 -func (ss *directSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { +func (ss *hybridSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { if off != ss.fileOffset { return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } - tempBuf := ss.bufPool.Get() - buf := tempBuf[:length] - n, err := io.ReadFull(ss.file, buf) - ss.fileOffset += int64(n) - if int64(n) != length { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) + b := ss.get() + if b == nil { + bOffset := int64(ss.hc.Size()) + cacheSize := length + for cacheSize > 0 { + blockSize := min(cacheSize, int64(conf.MaxBlockLimit)) + b2 := ss.hc.NextBlockWithSize(uint64(blockSize)) + if b2 == nil { + return nil, fmt.Errorf("failed to get cache section") + } + n, err := utils.CopyWithBufferN(buffer.WriteAtSeekerOf(b2), ss.file, blockSize) + ss.fileOffset += n + if n != blockSize { + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, length-cacheSize+n, err) + } + cacheSize -= n + } + b = buffer.NewBlockAdapter( + io.NewOffsetWriter(ss.hc, bOffset), + io.NewSectionReader(ss.hc, bOffset, length), + ) + } else { + n, err := utils.CopyWithBufferN(buffer.WriteAtSeekerOf(b), ss.file, length) + ss.fileOffset += n + if n != length { + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) + } } - return 
&bufferSectionReader{bytes.NewReader(buf), buf}, nil + + if length == b.Size() { + return &blockRefReadSeeker{buffer.ReadAtSeekerOf(b), b}, nil + } + return &blockRefReadSeeker{io.NewSectionReader(b, 0, length), b}, nil } -func (ss *directSectionReader) FreeSectionReader(rs io.ReadSeeker) { - if sr, ok := rs.(*bufferSectionReader); ok { - ss.bufPool.Put(sr.buf[0:cap(sr.buf)]) - sr.buf = nil + +func (ss *hybridSectionReader) get() buffer.Block { + if len(ss.cache) > 0 { + b := ss.cache[len(ss.cache)-1] + ss.cache = ss.cache[:len(ss.cache)-1] + return b + } + return nil +} +func (ss *hybridSectionReader) put(b buffer.Block) { + ss.cache = append(ss.cache, b) +} + +func (ss *hybridSectionReader) FreeSectionReader(rs io.ReadSeeker) { + if sr, ok := rs.(*blockRefReadSeeker); ok { + ss.put(sr.b) + sr.b = nil sr.ReadSeeker = nil } } diff --git a/pkg/buffer/bytes.go b/pkg/buffer/bytes.go index 3e6cb5405..5e479c607 100644 --- a/pkg/buffer/bytes.go +++ b/pkg/buffer/bytes.go @@ -84,6 +84,11 @@ func (r *Reader) Reset() { r.offset = 0 } +func (r *Reader) Close() error { + r.Reset() + return nil +} + func NewReader(buf ...[]byte) *Reader { b := &Reader{ bufs: make([][]byte, 0, len(buf)), @@ -93,3 +98,39 @@ func NewReader(buf ...[]byte) *Reader { } return b } + +type byteBlock struct { + buf []byte +} + +func NewByteBlock(buf []byte) Block { + return &byteBlock{buf: buf} +} + +func (b *byteBlock) Size() int64 { + return int64(len(b.buf)) +} + +func (b *byteBlock) ReadAt(p []byte, off int64) (n int, err error) { + if len(b.buf) == 0 || off < 0 || off >= b.Size() { + return 0, io.EOF + } + n = copy(p, b.buf[off:]) + if n < len(p) { + err = io.EOF + } + return +} + +func (b *byteBlock) WriteAt(p []byte, off int64) (n int, err error) { + if len(b.buf) == 0 || off < 0 || off >= b.Size() { + return 0, io.ErrShortWrite + } + n = copy(b.buf[off:], p) + if n < len(p) { + err = io.ErrShortWrite + } + return +} + +var _ Block = (*byteBlock)(nil) diff --git a/pkg/buffer/pipe.go 
b/pkg/buffer/pipe.go new file mode 100644 index 000000000..194fd58cf --- /dev/null +++ b/pkg/buffer/pipe.go @@ -0,0 +1,157 @@ +package buffer + +import ( + "context" + "fmt" + "io" + "sync" +) + +type PipeBuffer struct { + limit int //expected size + ctx context.Context + offR int + offW int + rw sync.Mutex + block Block + + readSignal chan struct{} + readPending bool +} + +// NewPipeBuffer is a buffer that can have 1 read & 1 write at the same time. +// when read is faster write, immediately feed data to read after written +func NewPipeBuffer(ctx context.Context, block Block) *PipeBuffer { + br := &PipeBuffer{ + ctx: ctx, + limit: int(block.Size()), + readSignal: make(chan struct{}, 1), + block: block, + } + return br +} + +func (br *PipeBuffer) Read(p []byte) (int, error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + if br.offR >= br.limit { + return 0, io.EOF + } + + for { + br.rw.Lock() + if br.block == nil { + br.rw.Unlock() + return 0, io.ErrClosedPipe + } + + if br.offW == br.offR { + br.readPending = true + br.rw.Unlock() + select { + case <-br.ctx.Done(): + return 0, br.ctx.Err() + case _, ok := <-br.readSignal: + if !ok { + return 0, io.ErrClosedPipe + } + continue + } + } + break + } + + canRead := br.offW - br.offR + if canRead < 0 { + br.rw.Unlock() + return 0, io.ErrUnexpectedEOF + } + + off := br.offR + block := br.block + br.rw.Unlock() + + n, err := block.ReadAt(p[:min(len(p), canRead)], int64(off)) + + br.rw.Lock() + br.offR += n + br.rw.Unlock() + + if n < len(p) && br.offR >= br.limit { + return n, io.EOF + } + return n, err +} + +func (br *PipeBuffer) Write(p []byte) (int, error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + + br.rw.Lock() + if br.block == nil { + br.rw.Unlock() + return 0, io.ErrClosedPipe + } + + canWrite := br.limit - br.offW + if canWrite <= 0 { + br.rw.Unlock() + return 0, io.ErrShortWrite + } + + off := br.offW 
+ block := br.block + br.rw.Unlock() + + n, err := block.WriteAt(p[:min(canWrite, len(p))], int64(off)) + + br.rw.Lock() + br.offW += n + if br.readPending { + br.readPending = false + select { + case br.readSignal <- struct{}{}: + default: + } + } + br.rw.Unlock() + + if n < len(p) && err == nil { + return n, io.ErrShortWrite + } + return n, err +} + +func (br *PipeBuffer) Reset(limit int) error { + br.rw.Lock() + defer br.rw.Unlock() + if br.block == nil { + return io.ErrClosedPipe + } + if int64(limit) > br.block.Size() { + return fmt.Errorf("reset limit %d exceeds max size %d", limit, br.block.Size()) + } + br.limit = limit + br.offR = 0 + br.offW = 0 + return nil +} + +func (br *PipeBuffer) Close() error { + br.rw.Lock() + defer br.rw.Unlock() + if br.block != nil { + br.block = nil + br.readPending = false + close(br.readSignal) + } + return nil +} diff --git a/pkg/buffer/type.go b/pkg/buffer/type.go new file mode 100644 index 000000000..e4d713216 --- /dev/null +++ b/pkg/buffer/type.go @@ -0,0 +1,22 @@ +package buffer + +import ( + "io" + + "github.com/OpenListTeam/OpenList/v4/internal/model" +) + +type Block interface { + io.ReaderAt + io.WriterAt + Size() int64 +} + +type WriteAtSeeker = model.FileWriter + +type ReadAtSeeker = model.File + +type SizedReadAtSeeker interface { + ReadAtSeeker + Size() int64 +} diff --git a/pkg/buffer/utils.go b/pkg/buffer/utils.go new file mode 100644 index 000000000..20358995c --- /dev/null +++ b/pkg/buffer/utils.go @@ -0,0 +1,42 @@ +package buffer + +import "io" + +type WriteAtSeekerProvider interface{ GetWriteAtSeeker() WriteAtSeeker } + +func WriteAtSeekerOf(b Block) WriteAtSeeker { + if p, ok := b.(WriteAtSeekerProvider); ok { + return p.GetWriteAtSeeker() + } + return io.NewOffsetWriter(b, 0) +} + +type ReadAtSeekerProvider interface{ GetReadAtSeeker() ReadAtSeeker } + +func ReadAtSeekerOf(b Block) ReadAtSeeker { + if p, ok := b.(ReadAtSeekerProvider); ok { + return p.GetReadAtSeeker() + } + return io.NewSectionReader(b, 
0, b.Size()) +} + +type blockAdapter struct { + WriteAtSeeker + SizedReadAtSeeker +} + +func (b *blockAdapter) GetWriteAtSeeker() WriteAtSeeker { + return b.WriteAtSeeker +} + +func (b *blockAdapter) GetReadAtSeeker() ReadAtSeeker { + return b.SizedReadAtSeeker +} +func NewBlockAdapter(w WriteAtSeeker, r SizedReadAtSeeker) Block { + return &blockAdapter{ + WriteAtSeeker: w, + SizedReadAtSeeker: r, + } +} + +var _ Block = (*blockAdapter)(nil) diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index ce92cd1fc..6fe08246a 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -1,18 +1,12 @@ package pool -import "sync" - +// A simple object pool implementation. Not thread-safe. type Pool[T any] struct { - New func() T - MaxCap int - + New func() T cache []T - mu sync.Mutex } func (p *Pool[T]) Get() T { - p.mu.Lock() - defer p.mu.Unlock() if len(p.cache) == 0 { return p.New() } @@ -22,16 +16,15 @@ func (p *Pool[T]) Get() T { } func (p *Pool[T]) Put(item T) { - p.mu.Lock() - defer p.mu.Unlock() - if p.MaxCap == 0 || len(p.cache) < int(p.MaxCap) { - p.cache = append(p.cache, item) - } + p.cache = append(p.cache, item) } func (p *Pool[T]) Reset() { - p.mu.Lock() - defer p.mu.Unlock() clear(p.cache) p.cache = nil } + +func (p *Pool[T]) Close() error { + p.Reset() + return nil +}