多线程HTTP下载器的简单实现

· 3209字 · 7分钟

1. 缘起 🔗

今天逛V2EX发现一个比较有意思的帖子“多线程分段下载文件时,为什么不下载到同一个大文件中?而是要分别下载到单独的文件然后再合并。 - V2EX”,楼主认为一些多线程HTTP文件下载器把文件下载到多个小文件后再合并是没有必要的,我深以为然。

部分人的看法在于把文件下载到多个独立的小文件,这样可以避免部分文件分块下载错误的问题。然而实际上将文件分为小文件后合并的实现方式并不能解决文件部分下载失败的问题,该失败还是得失败重试,而这种观点暴露了部分人并不了解可以对文件Seek。相比于下载为多个小文件后合并,直接创建一个文件,对文件的不同偏移量进行写入不需要合并文件,提高了下载的效率,也不会带来并发写入的安全问题。为了验证多线程下载不需要下载到多个小文件后合并,我做了一个简单实现,也算是增强自己对IO的理解。

2. 多线程下载器的设计 🔗

2.1 整体思路 🔗

多线程下载的优势在于,HTTP服务器或者网站的网关可能会按照连接对文件下载速率做一定限制,如果忽略客户端的请求速率限制和网络通信链路的限制,我们可以通过多条连接下载提高整体下载的速率。那么我们如何实现多线程的下载呢?要实现HTTP多线程下载,就需要有三个机制:

  1. 在下载前得到待下载文件的整体大小

  2. 并发线程下载时可以指定下载文件的偏移量

  3. 多个线程随机写入到文件的指定位置

HTTP协议实现了前两个机制,而操作系统实现了第三个机制:

  1. 我们可以通过HEAD请求获取待下载资源的长度。服务器收到对资源的HEAD请求后,需要返回相应的GET请求的响应头,而不能返回Body信息。通过HEAD请求,我们可以得到待下载资源的字节数,这可以通过响应头的Content-Length字段得到,Content-Length响应头的语法如下:

    Content-Length: <length>
    
  2. 我们可以通过HTTP的Range请求头指定服务器返回待下载文件的某个分段,这也是HTTP断点续传的原理。如果服务器支持的话,会返回206 Partial Content 响应并携带该分块的数据,否则返回HTTP / 200 OK响应及整个文件数据。Range请求头的语法如下,unit通常是字节bytes

    Range: <unit>=<range-start>-
    Range: <unit>=<range-start>-<range-end>
    Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>
    Range: <unit>=<range-start>-<range-end>, <range-start>-<range-end>, <range-start>-<range-end>
    Range: <unit>=-<suffix-length>
    
  3. 操作系统提供了lseek(2)机制,可以实现对文件读写偏移量的重定位。当我们要在偏移量offset的位置写入数据时,可以先通过lseek将文件描述符的偏移量移动到指定位置,后面再写。POSIX的lseek(2)提供3个重定位方式,即SEEK_SET(绝对偏移量)、SEEK_CUR(当前偏移量+相对偏移量)、SEEK_END(文件大小+相对偏移量)。

2.2 并发安全的文件写入 🔗

下载为多个小文件后合并的实现方法比较偷懒,通过HEAD得到文件的总长度后,我们可以按照下载线程数N创建N个临时的小文件,之后创建N个并发worker将分块下载的数据写入到这几个文件中。为了避免创建多个文件并合并带来的IO性能损耗,我们可以考虑只创建一个文件,每个worker下载的文件偏移量都可以通过文件分段大小和worker的标号得到。如果文件的总长度为SIZE,有N个worker,那么每个worker的下载大小为SIZE / worker下取整,如果无法均分则最后一个标号的worker需要多下载一些。

那我们多个worker并发去写同一个文件会不会有并发安全问题呢?如果我们只打开文件一次,一个文件对象在多个线程共享,那么肯定会有并发安全问题,因为多个线程操作同一个文件对象,共享的是同样的文件描述符(File Descriptor, FD),所有可能导致文件的覆盖。如果只有一个文件对象,我们必须用一个互斥锁同步对文件的读写,每个线程对文件的一次写入都需要获得锁之后,再通过lseek(2)修改FD偏移量,之后执行一次write,之后再解锁。毫无疑问互斥锁访问同一个文件对象是不可接受的,频繁的加锁解锁会严重影响写入性能,不符合我们的预期。

好在现代操作系统允许在一个进程内多次打开同一个文件,得到多个具有独立偏移量的文件描述符,例如通过man 2 open我们可以看到Linux的man page写了这么一段话:

Each open() of a file creates a new open file description; thus, there may be multiple open file descriptions corresponding to a file inode.

也就是说我们可以在一个进程中的多个并发的worker中分别打开该临时文件,之后seek到各自的起始位置写入,这样可以不用对文件读写做同步,如下图所示。

multi-thread-downloader-design

2.3 稀疏文件的额外收益 🔗

大多数UNIX文件系统都支持稀疏文件(Sparse File),当我们将文件描述符的偏移量移动到文件尾部之后再执行一次写入,那么从程序员的视角上看之前文件的尾部到写入的字节之间多出了一段以0填充的文件区域,这段区域被称为文件空洞,并不占用任何磁盘空间(严谨来说大部分文件系统以块为单位分配空间,所以根据文件空洞所处的位置不同可能还会占用一定空间,但是占用量很少)。因此当我们的下载任务未完成时,实际上未下载的数据大部分是不占用实际的存储空间的,需要注意。

3. 多线程下载器的实现 🔗

我用Go写了一个简单的命令行程序,默认创建的并发worker的数量是逻辑CPU的个数,可以通过flag -n指定并发worker的数量。通过channel,可以同步每个worker的下载状态。然而,这是一种很简单的实现,实际上要写一个完善的并发下载器需要更多的代码,包括进行任务持久化和恢复、根据文件头内容解析文件名和错误处理等。

package main

import (
	"errors"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"runtime"
	"strconv"
	"strings"

	"github.com/schollz/progressbar/v3"
)

/*
	1. Use HEAD method to get the file size(HTTP response header Content-Length)
	2. Create a temporary file to store the downloaded content
	3. Calculate the range of each slice for each worker
	4. Send a GET request with Range header to download the slice
	5. Write the downloaded content to the temporary file

	Note: To avoid concurrent write to the same file, we open the file in each worker.

	TODO: Add handling logic for the following cases:
	- The server sends a 429 Too Many Requests status code, slow down the speed
	- Custom rate limiter to control the download speed
	- Add a timeout for the HTTP request
	- Add a retry mechanism for the failed HTTP request
	- Add a checksum to verify the downloaded content
*/

// HTTP status code that may occur:
// 200 OK: we can only download the complete file
// 206 Partial Content: we can download the file slice separately
// 429 Too Many Requests: we need to slow down the download speed

const (
	ProgressBarHint = "Downloading task: "
)

var (
	// number of the concurrent downloading workers. If not specified, the count is set to the number of CPU cores
	numWorkers int

	// ErrDownloadingNotCompleted indicates the task finished incompletely
	ErrDownloadingNotCompleted = errors.New("downloading not completed")
)

// DownloadingTask represents an HTTP file downloading task
type DownloadingTask struct {
	// the URL of the file to download
	Url string

	// the file name to save the downloaded content
	FileName string

	// the temporary file name to store the unfinished downloaded content
	TempFileName string

	// the length of the total file in bytes, -1 means unknown file size
	Length int64

	// concurrent workers to download the file
	NumWorker int

	// the downloading status of each worker
	CompleteStatus []bool
}

func NewDownloadingTask(fileURL string, numWorker int) (*DownloadingTask, error) {
	// fetch the content length of the file
	resp, err := http.Head(fileURL)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
	if err != nil {
		numWorker = 1
		contentLength = -1
	}

	// create a tmp file to store the downloaded content
	file, err := os.CreateTemp(".", "multi-downloader-*")
	if err != nil {
		return nil, err
	}
	defer file.Close()

	// create a downloading task and start downloading
	return &DownloadingTask{
		Url:            fileURL,
		Length:         contentLength,
		NumWorker:      numWorker,
		CompleteStatus: make([]bool, numWorker),
		TempFileName:   file.Name(),
		FileName:       parseFileName(fileURL),
	}, nil
}

// IsComplete checks if all the workers have finished downloading
func (t *DownloadingTask) IsComplete() bool {
	for _, status := range t.CompleteStatus {
		if !status {
			return false
		}
	}
	return true
}

// Start starts the downloading task by assigning the downloading task to all the workers.
func (t *DownloadingTask) Start() error {
	// use a channel to receive the downloading status for each slice
	statusStream := make(chan TaskSliceStatus)
	pb := progressbar.DefaultBytes(t.Length, ProgressBarHint) // use progress bar to show the downloading progress
	for i := 0; i < t.NumWorker; i++ {
		t.startSlice(i, statusStream, pb)
	}

	// collect the downloading status from all the workers
	for i := 0; i < t.NumWorker; i++ {
		status := <-statusStream
		if status.Err == nil {
			t.CompleteStatus[status.WorkerId] = true
		} else {
			fmt.Printf("failed to download slice %d: %s\n", status.WorkerId, status.Err.Error())
		}
	}

	if t.IsComplete() {
		return os.Rename(t.TempFileName, t.FileName)
	}
	return ErrDownloadingNotCompleted
}

// Resume resumes downloading unfinished slices of the task
func (t *DownloadingTask) Resume() error {
	statusStream := make(chan TaskSliceStatus)
	pb := progressbar.DefaultBytes(-1, ProgressBarHint) // use progress bar to show the downloading progress

	countMissingSlice := 0
	for i := 0; i < t.NumWorker; i++ {
		if !t.CompleteStatus[i] {
			countMissingSlice++
			t.startSlice(i, statusStream, pb)
		}
	}

	for i := 0; i < countMissingSlice; i++ {
		status := <-statusStream
		if status.Err == nil {
			t.CompleteStatus[status.WorkerId] = true
		} else {
			fmt.Printf("failed to download slice %d: %s\n", status.WorkerId, status.Err.Error())
		}
	}

	if t.IsComplete() {
		return nil
	}
	return ErrDownloadingNotCompleted
}

func (t *DownloadingTask) startSlice(sliceNumber int, statusStream chan<- TaskSliceStatus, pb io.Writer) {
	go func(workerId int) {
		statusStream <- TaskSliceStatus{
			WorkerId: workerId,
			Err:      downloadSlice(t, workerId, pb),
		}
	}(sliceNumber)
}

type TaskSliceStatus struct {
	WorkerId int
	Err      error
}

func downloadSlice(task *DownloadingTask, workerId int, pb io.Writer) error {
	// open the file in each worker, so we don't have to synchronize the file access
	// each worker's file descriptor has its own offset
	file, err := os.OpenFile(task.TempFileName, os.O_WRONLY, 0644)
	if err != nil {
		return err
	}
	defer file.Close()

	// calculate the range of the slice
	sliceLen := task.Length / int64(task.NumWorker)
	start := sliceLen * int64(workerId)
	end := start + sliceLen - 1
	if workerId == task.NumWorker-1 {
		end = task.Length - 1
	}
	if _, err = file.Seek(start, io.SeekStart); err != nil {
		return err
	}

	w := io.MultiWriter(file, pb)
	req, err := http.NewRequest("GET", task.Url, nil)
	if err != nil {
		return err
	}

	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0")
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusPartialContent && len(task.CompleteStatus) > 1 {
		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	_, err = io.Copy(w, resp.Body)
	return err
}

// parseFileName extracts the file name from the file URL
// by removing the query string and the path
func parseFileName(fileURL string) string {
	idx := strings.Index(fileURL, "?")
	if idx != -1 {
		fileURL = fileURL[:idx]
	}

	idx = strings.LastIndex(fileURL, "/")
	if idx != -1 {
		return fileURL[idx+1:]
	}
	return fileURL
}

func init() {
	flag.IntVar(&numWorkers, "n", runtime.NumCPU(), "number of multi-thread workers")
	flag.Parse()

	if len(os.Args) < 2 {
		fmt.Println("Usage: " + os.Args[0] + " <url>")
		os.Exit(1)
	}
}

func main() {
	task, err := NewDownloadingTask(flag.Args()[0], numWorkers)
	if err != nil {
		fmt.Println("failed to initialize a downloading task:", err.Error())
		os.Exit(1)
	}

	fmt.Printf("Using %d workers to download the file\n", task.NumWorker)
	if err := task.Start(); err != nil {
		fmt.Println("failed to download the file:", err.Error())
		if err = task.Resume(); err != nil {
			fmt.Println("failed to resume the downloading task:", err.Error())
			os.Exit(1)
		}
	}
}

运行结果如下图:

multi-thread-downloader-result

4. 参考文献 🔗

comments powered by Disqus