packwiz/core/download.go
comp500 f3837af145 Completed download implementation for CF export
Added support for importing manual files and rehashing where necessary
Moved cache folder to "local" user folder
Cleaned up messages, saved index after importing
2022-05-21 03:40:00 +01:00

653 lines
19 KiB
Go

package core
import (
"encoding/json"
"errors"
"fmt"
"golang.org/x/exp/slices"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
)
// DownloadCacheImportFolder is the subfolder of the download cache into which users can
// place manually downloaded files, to be hashed and imported into the cache proper.
const DownloadCacheImportFolder = "import"
// DownloadSession manages the retrieval of a set of mod files into the download cache,
// covering both automatic downloads and files the user must obtain manually.
type DownloadSession interface {
	// GetManualDownloads returns the files that the user must download themselves.
	GetManualDownloads() []ManualDownload
	// StartDownloads begins resolving all files asynchronously, sending one
	// CompletedDownload per file on the returned channel, which is closed when done.
	StartDownloads() chan CompletedDownload
	// SaveIndex writes the cache index back to disk.
	SaveIndex() error
}
// CompletedDownload is the result of resolving a single mod file, either by
// downloading it or by reusing a copy already stored in the cache.
type CompletedDownload struct {
	// File is only populated when the download is successful; points to the opened cache file
	File *os.File
	Mod  *Mod
	// Hashes is only populated when the download is successful; contains all stored hashes of the file
	Hashes map[string]string
	// Error indicates if/why downloading this file failed
	Error error
	// Warnings indicates messages to show to the user regarding this file (download was successful, but had a problem)
	Warnings []error
}
// downloadSessionInternal is the DownloadSession implementation backed by the
// on-disk, hash-indexed download cache.
type downloadSessionInternal struct {
	cacheIndex           CacheIndex          // in-memory copy of the cache's hash index
	cacheFolder          string              // root folder of the download cache
	hashesToObtain       []string            // hash formats that must be known for every file
	manualDownloads      []ManualDownload    // files the user must download themselves
	downloadTasks        []downloadTask      // files to be downloaded automatically
	foundManualDownloads []CompletedDownload // manually downloaded files already present in the cache
}
// downloadTask describes one file to download, either from a direct URL or via
// previously retrieved MetaDownloaderData (used when url is empty).
type downloadTask struct {
	metaDownloaderData MetaDownloaderData // source of the file when url is empty
	mod                *Mod
	url                string // direct download URL, may be empty
	hashFormat         string // format of hash, the expected hash to validate against
	hash               string
}
// GetManualDownloads returns the files that the user must download themselves.
func (d *downloadSessionInternal) GetManualDownloads() []ManualDownload {
	return d.manualDownloads
}
// StartDownloads spawns a goroutine that resolves every file in the session,
// sending one CompletedDownload per file on the returned channel; the channel
// is closed once all files have been processed.
func (d *downloadSessionInternal) StartDownloads() chan CompletedDownload {
	results := make(chan CompletedDownload)
	go func() {
		defer close(results)
		// Manually downloaded files already located in the cache are sent first
		for _, manual := range d.foundManualDownloads {
			results <- manual
		}
		for _, task := range d.downloadTasks {
			var completed CompletedDownload
			var err error
			// Prefer an existing cached copy of this file, if the index knows one
			if handle := d.cacheIndex.GetHandleFromHash(task.hashFormat, task.hash); handle != nil {
				completed, err = reuseExistingFile(handle, d.hashesToObtain, task.mod)
			} else {
				completed, err = downloadNewFile(&task, d.cacheFolder, d.hashesToObtain, &d.cacheIndex)
			}
			if err != nil {
				results <- CompletedDownload{
					Error: err,
					Mod:   task.mod,
				}
			} else {
				results <- completed
			}
		}
	}()
	return results
}
// SaveIndex serialises the cache index and writes it to index.json in the cache folder.
func (d *downloadSessionInternal) SaveIndex() error {
	indexPath := filepath.Join(d.cacheFolder, "index.json")
	serialised, err := json.Marshal(d.cacheIndex)
	if err != nil {
		return fmt.Errorf("failed to serialise index: %w", err)
	}
	if err := ioutil.WriteFile(indexPath, serialised, 0644); err != nil {
		return fmt.Errorf("failed to write index: %w", err)
	}
	return nil
}
// reuseExistingFile opens a file already stored in the cache, computes any hashes in
// hashesToObtain not yet recorded for it (updating the index with them), and returns
// it as a CompletedDownload with the file rewound to the start.
func reuseExistingFile(cacheHandle *CacheIndexHandle, hashesToObtain []string, mod *Mod) (CompletedDownload, error) {
	// Already stored; try using it!
	file, err := cacheHandle.Open()
	if err != nil {
		return CompletedDownload{}, fmt.Errorf("failed to read file %s from cache: %w", cacheHandle.Path(), err)
	}
	remainingHashes := cacheHandle.GetRemainingHashes(hashesToObtain)
	if len(remainingHashes) > 0 {
		// Read the file once to compute the missing hashes; output is discarded
		err = teeHashes(remainingHashes, cacheHandle.Hashes, io.Discard, file)
		if err != nil {
			_ = file.Close()
			return CompletedDownload{}, fmt.Errorf("failed to read hashes of file %s from cache: %w", cacheHandle.Path(), err)
		}
		// Rewind so the caller reads the file from the beginning
		_, err = file.Seek(0, io.SeekStart)
		if err != nil {
			_ = file.Close()
			return CompletedDownload{}, fmt.Errorf("failed to seek file %s in cache: %w", cacheHandle.Path(), err)
		}
		cacheHandle.UpdateIndex()
	}
	return CompletedDownload{
		File:   file,
		Mod:    mod,
		Hashes: cacheHandle.Hashes,
	}, nil
}
// downloadNewFile downloads the file for the given task into a temporary file,
// validating it against the task's expected hash and computing all hashes in
// hashesToObtain, then moves it into the cache (or reuses an existing cache entry
// that already has the same hashes).
func downloadNewFile(task *downloadTask, cacheFolder string, hashesToObtain []string, index *CacheIndex) (CompletedDownload, error) {
	// Create temp file to download to
	tempFile, err := ioutil.TempFile(filepath.Join(cacheFolder, "temp"), "download-tmp")
	if err != nil {
		return CompletedDownload{}, fmt.Errorf("failed to create temporary file for download: %w", err)
	}
	// discardTemp closes and deletes the temporary file when bailing out before it
	// reaches the cache, so failed downloads don't accumulate in the temp folder
	discardTemp := func() {
		_ = tempFile.Close()
		_ = os.Remove(tempFile.Name())
	}
	hashesToObtain, hashes := getHashListsForDownload(hashesToObtain, task.hashFormat, task.hash)
	if len(hashesToObtain) > 0 {
		var data io.ReadCloser
		if task.url != "" {
			resp, err := http.Get(task.url)
			// TODO: content type, user-agent?
			if err != nil {
				discardTemp()
				return CompletedDownload{}, fmt.Errorf("failed to download %s: %w", task.url, err)
			}
			if resp.StatusCode != 200 {
				_ = resp.Body.Close()
				discardTemp()
				return CompletedDownload{}, fmt.Errorf("failed to download %s: invalid status code %v", task.url, resp.StatusCode)
			}
			data = resp.Body
		} else {
			data, err = task.metaDownloaderData.DownloadFile()
			if err != nil {
				discardTemp()
				return CompletedDownload{}, err
			}
		}
		// Stream into the temp file while computing/validating hashes
		err = teeHashes(hashesToObtain, hashes, tempFile, data)
		_ = data.Close()
		if err != nil {
			discardTemp()
			return CompletedDownload{}, fmt.Errorf("failed to download: %w", err)
		}
	}
	// Create handle with calculated hashes
	cacheHandle, alreadyExists := index.NewHandleFromHashes(hashes)
	// Update index stored hashes
	warnings := cacheHandle.UpdateIndex()
	var file *os.File
	if alreadyExists {
		// A file with these hashes is already cached; use it and drop the new copy
		err = tempFile.Close()
		if err != nil {
			return CompletedDownload{}, fmt.Errorf("failed to close temporary file %s: %w", tempFile.Name(), err)
		}
		// Best-effort removal of the now-redundant temporary file
		_ = os.Remove(tempFile.Name())
		file, err = cacheHandle.Open()
		if err != nil {
			return CompletedDownload{}, fmt.Errorf("failed to read file %s from cache: %w", cacheHandle.Path(), err)
		}
	} else {
		// Automatically closes tempFile
		file, err = cacheHandle.CreateFromTemp(tempFile)
		if err != nil {
			// tempFile is already closed here; just remove the leftover file
			_ = os.Remove(tempFile.Name())
			return CompletedDownload{}, fmt.Errorf("failed to move file %s to cache: %w", cacheHandle.Path(), err)
		}
	}
	return CompletedDownload{
		File:     file,
		Mod:      task.mod,
		Hashes:   hashes,
		Warnings: warnings,
	}, nil
}
// selectPreferredHash picks the most preferred hash format present in the given map
// (entries later in preferredHashList win) and returns that format with its value.
// Both results are empty when none of the preferred formats are present.
func selectPreferredHash(hashes map[string]string) (currHashFormat string, currHash string) {
	// Scan the preference list from the most preferred end and stop at the first hit
	for i := len(preferredHashList) - 1; i >= 0; i-- {
		if value, present := hashes[preferredHashList[i]]; present {
			return preferredHashList[i], value
		}
	}
	return
}
// getHashListsForDownload creates a hashes map with the given validate hash+format,
// ensures cacheHashFormat is in hashesToObtain (cloned+returned) and validateHashFormat isn't
func getHashListsForDownload(hashesToObtain []string, validateHashFormat string, validateHash string) ([]string, map[string]string) {
	hashes := map[string]string{validateHashFormat: validateHash}
	// The cache's canonical hash format is always obtained; formats that are
	// already covered (the validation format, or the cache format itself) are skipped
	obtain := make([]string, 1, len(hashesToObtain)+1)
	obtain[0] = cacheHashFormat
	for _, format := range hashesToObtain {
		if format != validateHashFormat && format != cacheHashFormat {
			obtain = append(obtain, format)
		}
	}
	return obtain, hashes
}
// teeHashes copies src to dst while computing the hashes named in hashesToObtain,
// validating the data against the best hash already present in the hashes map and
// storing the newly computed hashes back into that map on success.
func teeHashes(hashesToObtain []string, hashes map[string]string,
	dst io.Writer, src io.Reader) error {
	// Select the best hash from the hashes map to validate against
	validateHashFormat, validateHash := selectPreferredHash(hashes)
	if validateHashFormat == "" {
		return errors.New("failed to find preferred hash for file")
	}
	// Create writers for all the hashers
	mainHasher, err := GetHashImpl(validateHashFormat)
	if err != nil {
		// Wrap the underlying error so the cause isn't lost
		return fmt.Errorf("failed to get hash format %s: %w", validateHashFormat, err)
	}
	hashers := make(map[string]HashStringer, len(hashesToObtain))
	allWriters := make([]io.Writer, len(hashesToObtain))
	for i, v := range hashesToObtain {
		hashers[v], err = GetHashImpl(v)
		if err != nil {
			return fmt.Errorf("failed to get hash format %s: %w", v, err)
		}
		allWriters[i] = hashers[v]
	}
	allWriters = append(allWriters, mainHasher, dst)
	// Copy source to all writers (all hashers and dst)
	w := io.MultiWriter(allWriters...)
	_, err = io.Copy(w, src)
	if err != nil {
		return fmt.Errorf("failed to read file: %w", err)
	}
	calculatedHash := mainHasher.HashToString(mainHasher.Sum(nil))
	// Check if the hash of the downloaded file matches the expected hash
	if calculatedHash != validateHash {
		return fmt.Errorf(
			"%s hash of downloaded file does not match with expected hash!\n download hash: %s\n expected hash: %s\n",
			validateHashFormat, calculatedHash, validateHash)
	}
	// Record all newly computed hashes for the caller
	for hashFormat, v := range hashers {
		hashes[hashFormat] = v.HashToString(v.Sum(nil))
	}
	return nil
}
// cacheHashFormat is the canonical hash format of the cache: files are stored under
// paths derived from this hash, and the index keeps one of these for every file.
const cacheHashFormat = "sha256"
// CacheIndex is the serialised index of the download cache. It maps hash formats to
// per-file hash lists; the same position in every list refers to the same file.
type CacheIndex struct {
	Version uint32              // index format version; this code understands versions <= 1
	Hashes  map[string][]string // hash format -> hashes, indexed by file position
	cachePath   string // root folder of the cache (not serialised)
	nextHashIdx int    // next free file position in the hash lists (not serialised)
}
// CacheIndexHandle refers to one file's entry (position hashIdx) in a CacheIndex,
// carrying the hashes currently known for that file.
type CacheIndexHandle struct {
	index   *CacheIndex
	hashIdx int               // position of this file in the index's hash lists
	Hashes  map[string]string // hash format -> hash value for this file
}
// getHashesMap collects every stored hash for the file at position i, keyed by hash
// format; formats with no stored value for this file are omitted.
func (c *CacheIndex) getHashesMap(i int) map[string]string {
	result := make(map[string]string)
	for format, list := range c.Hashes {
		if i >= len(list) {
			continue
		}
		if value := list[i]; value != "" {
			result[format] = value
		}
	}
	return result
}
// GetHandleFromHash looks up the given hash (in the given format) in the index,
// returning a handle to the matching file or nil if it isn't stored.
func (c *CacheIndex) GetHandleFromHash(hashFormat string, hash string) *CacheIndexHandle {
	list, ok := c.Hashes[hashFormat]
	if !ok {
		return nil
	}
	idx := slices.Index(list, hash)
	if idx < 0 {
		return nil
	}
	return &CacheIndexHandle{
		index:   c,
		hashIdx: idx,
		Hashes:  c.getHashesMap(idx),
	}
}
// GetHandleFromHashForce looks up the given hash in the index; but will rehash any file without this hash format to
// obtain the necessary hash. Only use this for manually downloaded files, as it can rehash every file in the cache, which
// can be more time-consuming than just redownloading the file and noticing it is already in the index!
// Returns (nil, nil) when no cached file has the given hash.
func (c *CacheIndex) GetHandleFromHashForce(hashFormat string, hash string) (*CacheIndexHandle, error) {
	storedHashFmtList, hasStoredHashFmt := c.Hashes[hashFormat]
	if hasStoredHashFmt {
		// Ensure hash list is extended to the length of the cache hash format list
		// (the cache format list has an entry for every cached file)
		storedHashFmtList = append(storedHashFmtList, make([]string, len(c.Hashes[cacheHashFormat])-len(storedHashFmtList))...)
		c.Hashes[hashFormat] = storedHashFmtList
		// Rehash every file that doesn't have this hash with this hash
		for hashIdx, curHash := range storedHashFmtList {
			if curHash == hash {
				// Hash already stored; return a handle to this file directly
				return &CacheIndexHandle{
					index:   c,
					hashIdx: hashIdx,
					Hashes:  c.getHashesMap(hashIdx),
				}, nil
			} else if curHash == "" {
				// This file has no hash in this format yet; compute it and record it
				// in the list (which aliases the index's stored list)
				var err error
				storedHashFmtList[hashIdx], err = c.rehashFile(c.Hashes[cacheHashFormat][hashIdx], hashFormat)
				if err != nil {
					return nil, fmt.Errorf("failed to rehash %s: %w", c.Hashes[cacheHashFormat][hashIdx], err)
				}
				if storedHashFmtList[hashIdx] == hash {
					return &CacheIndexHandle{
						index:   c,
						hashIdx: hashIdx,
						Hashes:  c.getHashesMap(hashIdx),
					}, nil
				}
			}
		}
	} else {
		// No hashes stored in this format at all; rehash every file in the cache
		storedHashFmtList = make([]string, len(c.Hashes[cacheHashFormat]))
		c.Hashes[hashFormat] = storedHashFmtList
		for hashIdx, cacheHash := range c.Hashes[cacheHashFormat] {
			var err error
			storedHashFmtList[hashIdx], err = c.rehashFile(cacheHash, hashFormat)
			if err != nil {
				return nil, fmt.Errorf("failed to rehash %s: %w", cacheHash, err)
			}
			if storedHashFmtList[hashIdx] == hash {
				return &CacheIndexHandle{
					index:   c,
					hashIdx: hashIdx,
					Hashes:  c.getHashesMap(hashIdx),
				}, nil
			}
		}
	}
	return nil, nil
}
// rehashFile reads the cached file named by cacheHash (in cacheHashFormat), verifying
// its integrity against cacheHash while computing and returning its hash in the given
// hashFormat.
func (c *CacheIndex) rehashFile(cacheHash string, hashFormat string) (string, error) {
	file, err := os.Open(filepath.Join(c.cachePath, cacheHash[:2], cacheHash[2:]))
	if err != nil {
		return "", err
	}
	// Release the file handle on every return path (previously leaked)
	defer file.Close()
	validateHasher, err := GetHashImpl(cacheHashFormat)
	if err != nil {
		return "", fmt.Errorf("failed to get hasher for rehash: %w", err)
	}
	rehashHasher, err := GetHashImpl(hashFormat)
	if err != nil {
		return "", fmt.Errorf("failed to get hasher for rehash: %w", err)
	}
	// One pass over the file feeds both the integrity check and the new hash
	writer := io.MultiWriter(validateHasher, rehashHasher)
	_, err = io.Copy(writer, file)
	if err != nil {
		return "", err
	}
	validateHash := validateHasher.HashToString(validateHasher.Sum(nil))
	if cacheHash != validateHash {
		return "", fmt.Errorf(
			"%s hash of cached file does not match with expected hash!\n read hash: %s\n expected hash: %s\n",
			cacheHashFormat, validateHash, cacheHash)
	}
	return rehashHasher.HashToString(rehashHasher.Sum(nil)), nil
}
// NewHandleFromHashes returns a handle for the cached file matching any of the given
// hashes (merging the provided hashes into it) and reports true; when nothing matches,
// it returns a handle for a freshly allocated index position and reports false.
func (c *CacheIndex) NewHandleFromHashes(hashes map[string]string) (*CacheIndexHandle, bool) {
	for format, value := range hashes {
		existing := c.GetHandleFromHash(format, value)
		if existing == nil {
			continue
		}
		// Merge the caller's hashes into the existing handle
		for mergeFormat, mergeValue := range hashes {
			existing.Hashes[mergeFormat] = mergeValue
		}
		return existing, true
	}
	// Not stored yet; claim the next free index position
	idx := c.nextHashIdx
	c.nextHashIdx++
	return &CacheIndexHandle{
		index:   c,
		hashIdx: idx,
		Hashes:  hashes,
	}, false
}
// MoveImportFiles hashes every file in the cache's import folder and moves it into
// the cache proper, deleting files that turn out to be already cached and recording
// the new files' hashes in the index.
func (c *CacheIndex) MoveImportFiles() error {
	return filepath.Walk(filepath.Join(c.cachePath, DownloadCacheImportFolder), func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			return nil
		}
		file, err := os.Open(path)
		if err != nil {
			// No valid handle exists when Open fails, so there is nothing to close
			return fmt.Errorf("failed to open imported file %s: %w", path, err)
		}
		hasher, err := GetHashImpl(cacheHashFormat)
		if err != nil {
			_ = file.Close()
			return fmt.Errorf("failed to validate imported file %s: %w", path, err)
		}
		// Hash the file in the cache's canonical format to locate/create its entry
		_, err = io.Copy(hasher, file)
		if err != nil {
			_ = file.Close()
			return fmt.Errorf("failed to validate imported file %s: %w", path, err)
		}
		handle, exists := c.NewHandleFromHashes(map[string]string{
			cacheHashFormat: hasher.HashToString(hasher.Sum(nil)),
		})
		if exists {
			// Already cached; just remove the imported duplicate
			err = file.Close()
			if err != nil {
				return fmt.Errorf("failed to close imported file %s: %w", path, err)
			}
			err = os.Remove(path)
			if err != nil {
				return fmt.Errorf("failed to delete imported file %s: %w", path, err)
			}
		} else {
			// Move the imported file into the cache (this closes file) and index it
			newFile, err := handle.CreateFromTemp(file)
			if err != nil {
				if newFile != nil {
					_ = newFile.Close()
				}
				return fmt.Errorf("failed to rename imported file %s: %w", path, err)
			}
			err = newFile.Close()
			if err != nil {
				return fmt.Errorf("failed to close renamed imported file %s: %w", path, err)
			}
			_ = handle.UpdateIndex()
		}
		return nil
	})
}
// GetRemainingHashes returns the subset of hashesToObtain for which this handle has
// no stored hash yet.
func (h *CacheIndexHandle) GetRemainingHashes(hashesToObtain []string) []string {
	var missing []string
	for _, format := range hashesToObtain {
		if _, known := h.Hashes[format]; known {
			continue
		}
		missing = append(missing, format)
	}
	return missing
}
// Path returns the on-disk location of this handle's cached file: a folder named by
// the first two characters of its cache-format hash, then the remainder of the hash.
func (h *CacheIndexHandle) Path() string {
	fileHash := h.Hashes[cacheHashFormat]
	return filepath.Join(h.index.cachePath, fileHash[:2], fileHash[2:])
}
// Open opens this handle's cached file for reading.
func (h *CacheIndexHandle) Open() (*os.File, error) {
	return os.Open(h.Path())
}
// CreateFromTemp closes the given temporary file, moves it into this handle's cache
// location, and reopens it from there for reading.
func (h *CacheIndexHandle) CreateFromTemp(temp *os.File) (*os.File, error) {
	// Close before renaming (renaming an open file fails on some platforms)
	if err := temp.Close(); err != nil {
		return nil, err
	}
	destination := h.Path()
	if err := os.MkdirAll(filepath.Dir(destination), 0755); err != nil {
		return nil, err
	}
	if err := os.Rename(temp.Name(), destination); err != nil {
		return nil, err
	}
	return os.Open(destination)
}
// UpdateIndex stores this handle's hashes into the index's hash lists, extending the
// lists as needed, and returns a warning for every stored hash that disagreed with
// the handle's value (the stored value is overwritten).
func (h *CacheIndexHandle) UpdateIndex() (warnings []error) {
	// Add hashes to index
	for format, value := range h.Hashes {
		list := h.index.Hashes[format]
		if h.hashIdx >= len(list) {
			// Pad the list with empty entries so hashIdx is a valid position
			list = append(list, make([]string, (h.hashIdx-len(list))+1)...)
			h.index.Hashes[format] = list
		}
		switch list[h.hashIdx] {
		case "":
			// No hash stored yet for this file in this format; record it
			list[h.hashIdx] = value
		case value:
			// Stored hash already agrees; nothing to do
		default:
			// Warn if the existing hash is inconsistent!
			warnings = append(warnings, fmt.Errorf("inconsistent %s hash for %s overwritten - value %s (expected %s)",
				format, h.Path(), list[h.hashIdx], value))
			list[h.hashIdx] = value
		}
	}
	return
}
// CreateDownloadSession builds a DownloadSession for the given mods: it loads (or
// initialises) the cache index, imports any manually downloaded files from the import
// folder, resolves metadata for non-URL downloads, and saves the updated index before
// returning. hashesToObtain lists the hash formats that must be known for every file.
func CreateDownloadSession(mods []*Mod, hashesToObtain []string) (DownloadSession, error) {
	// Load cache index
	cacheIndex := CacheIndex{Version: 1, Hashes: make(map[string][]string)}
	cachePath, err := GetPackwizCache()
	if err != nil {
		return nil, fmt.Errorf("failed to load cache: %w", err)
	}
	err = os.MkdirAll(cachePath, 0755)
	if err != nil {
		return nil, fmt.Errorf("failed to create cache directory: %w", err)
	}
	err = os.MkdirAll(filepath.Join(cachePath, "temp"), 0755)
	if err != nil {
		return nil, fmt.Errorf("failed to create cache temp directory: %w", err)
	}
	// A missing index file is not an error; the fresh defaults above are used
	cacheIndexData, err := ioutil.ReadFile(filepath.Join(cachePath, "index.json"))
	if err != nil {
		if !os.IsNotExist(err) {
			return nil, fmt.Errorf("failed to read cache index file: %w", err)
		}
	} else {
		err = json.Unmarshal(cacheIndexData, &cacheIndex)
		if err != nil {
			return nil, fmt.Errorf("failed to read cache index file: %w", err)
		}
		if cacheIndex.Version > 1 {
			return nil, fmt.Errorf("cache index is too new (version %v)", cacheIndex.Version)
		}
	}
	// Ensure some parts of the index are initialised
	_, hasCacheHashFmt := cacheIndex.Hashes[cacheHashFormat]
	if !hasCacheHashFmt {
		cacheIndex.Hashes[cacheHashFormat] = make([]string, 0)
	}
	cacheIndex.cachePath = cachePath
	// The cache-format hash list has one entry per cached file, so its length is
	// the next free index position
	cacheIndex.nextHashIdx = len(cacheIndex.Hashes[cacheHashFormat])
	// Create import folder
	err = os.MkdirAll(filepath.Join(cachePath, DownloadCacheImportFolder), 0755)
	if err != nil {
		return nil, fmt.Errorf("error creating cache import folder: %w", err)
	}
	// Move import files (must happen after the index is initialised above)
	err = cacheIndex.MoveImportFiles()
	if err != nil {
		return nil, fmt.Errorf("error updating cache import folder: %w", err)
	}
	// Create session
	downloadSession := downloadSessionInternal{
		cacheIndex:     cacheIndex,
		cacheFolder:    cachePath,
		hashesToObtain: hashesToObtain,
	}
	// Mods grouped by metadata downloader ID, to be resolved in bulk below
	pendingMetadata := make(map[string][]*Mod)
	// Get necessary metadata for all files
	for _, mod := range mods {
		if mod.Download.Mode == "url" || mod.Download.Mode == "" {
			// Plain URL download; becomes a task directly
			downloadSession.downloadTasks = append(downloadSession.downloadTasks, downloadTask{
				mod:        mod,
				url:        mod.Download.URL,
				hashFormat: mod.Download.HashFormat,
				hash:       mod.Download.Hash,
			})
		} else if strings.HasPrefix(mod.Download.Mode, "metadata:") {
			// Deferred to the matching MetaDownloader
			dlID := strings.TrimPrefix(mod.Download.Mode, "metadata:")
			pendingMetadata[dlID] = append(pendingMetadata[dlID], mod)
		} else {
			return nil, fmt.Errorf("unknown download mode %s for mod %s", mod.Download.Mode, mod.Name)
		}
	}
	for dlID, mods := range pendingMetadata {
		downloader, ok := MetaDownloaders[dlID]
		if !ok {
			return nil, fmt.Errorf("unknown download mode %s for mod %s", mods[0].Download.Mode, mods[0].Name)
		}
		// meta is parallel to mods: meta[i] is the metadata for mods[i]
		meta, err := downloader.GetFilesMetadata(mods)
		if err != nil {
			return nil, fmt.Errorf("failed to retrieve %s files: %w", dlID, err)
		}
		for i, v := range mods {
			isManual, manualDownload := meta[i].GetManualDownload()
			if isManual {
				// Check the cache first, rehashing files if necessary, so the user
				// isn't asked to download a file that is already present
				handle, err := cacheIndex.GetHandleFromHashForce(v.Download.HashFormat, v.Download.Hash)
				if err != nil {
					return nil, fmt.Errorf("failed to lookup manual download %s: %w", v.Name, err)
				}
				if handle != nil {
					file, err := handle.Open()
					if err != nil {
						return nil, fmt.Errorf("failed to open manual download %s: %w", v.Name, err)
					}
					downloadSession.foundManualDownloads = append(downloadSession.foundManualDownloads, CompletedDownload{
						File:   file,
						Mod:    v,
						Hashes: handle.Hashes,
					})
				} else {
					downloadSession.manualDownloads = append(downloadSession.manualDownloads, manualDownload)
				}
			} else {
				downloadSession.downloadTasks = append(downloadSession.downloadTasks, downloadTask{
					mod:                v,
					metaDownloaderData: meta[i],
					hashFormat:         v.Download.HashFormat,
					hash:               v.Download.Hash,
				})
			}
		}
	}
	// TODO: index housekeeping? i.e. remove deleted files, remove old files (LRU?)
	// Save index after importing and Force index updates
	err = downloadSession.SaveIndex()
	if err != nil {
		return nil, fmt.Errorf("error writing cache index: %w", err)
	}
	return &downloadSession, nil
}