Merge pull request #3067 from DRON-666/vss-options

Add options to fine tune VSS snapshots
This commit is contained in:
Michael Eischer 2024-04-29 18:09:47 +00:00 committed by GitHub
commit ccac7c7fb3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 744 additions and 95 deletions

View file

@ -0,0 +1,22 @@
Enhancement: Add options to configure Windows Shadow Copy Service
Restic always used 120 seconds timeout and unconditionally created VSS snapshots
for all volume mount points on disk. Now this behavior can be fine-tuned by
new options, like exclude specific volumes and mount points or completely
disable auto snapshotting of volume mount points.
For example:
restic backup --use-fs-snapshot -o vss.timeout=5m -o vss.exclude-all-mount-points=true
changes timeout to five minutes and disable snapshotting of mount points on all volumes, and
restic backup --use-fs-snapshot -o vss.exclude-volumes="d:\;c:\mnt\;\\?\Volume{e2e0315d-9066-4f97-8343-eb5659b35762}"
excludes drive `d:`, mount point `c:\mnt` and specific volume from VSS snapshotting.
restic backup --use-fs-snapshot -o vss.provider={b5946137-7b9f-4925-af80-51abd60b20d5}
uses 'Microsoft Software Shadow Copy provider 1.0' instead of the default provider.
https://github.com/restic/restic/pull/3067

View file

@ -445,7 +445,16 @@ func findParentSnapshot(ctx context.Context, repo restic.ListerLoaderUnpacked, o
} }
func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, term *termstatus.Terminal, args []string) error { func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, term *termstatus.Terminal, args []string) error {
err := opts.Check(gopts, args) var vsscfg fs.VSSConfig
var err error
if runtime.GOOS == "windows" {
if vsscfg, err = fs.ParseVSSConfig(gopts.extended); err != nil {
return err
}
}
err = opts.Check(gopts, args)
if err != nil { if err != nil {
return err return err
} }
@ -547,8 +556,8 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter
return err return err
} }
errorHandler := func(item string, err error) error { errorHandler := func(item string, err error) {
return progressReporter.Error(item, err) _ = progressReporter.Error(item, err)
} }
messageHandler := func(msg string, args ...interface{}) { messageHandler := func(msg string, args ...interface{}) {
@ -557,7 +566,7 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter
} }
} }
localVss := fs.NewLocalVss(errorHandler, messageHandler) localVss := fs.NewLocalVss(errorHandler, messageHandler, vsscfg)
defer localVss.DeleteSnapshots() defer localVss.DeleteSnapshots()
targetFS = localVss targetFS = localVss
} }

View file

@ -56,6 +56,39 @@ snapshot for each volume that contains files to backup. Files are read from the
VSS snapshot instead of the regular filesystem. This allows to backup files that are VSS snapshot instead of the regular filesystem. This allows to backup files that are
exclusively locked by another process during the backup. exclusively locked by another process during the backup.
You can use additional options to change VSS behaviour:
* ``-o vss.timeout`` specifies timeout for VSS snapshot creation, the default value is 120 seconds
* ``-o vss.exclude-all-mount-points`` disable auto snapshotting of all volume mount points
* ``-o vss.exclude-volumes`` allows excluding specific volumes or volume mount points from snapshotting
* ``-o vss.provider`` specifies VSS provider used for snapshotting
For example a 2.5 minutes timeout with snapshotting of mount points disabled can be specified as
.. code-block:: console
-o vss.timeout=2m30s -o vss.exclude-all-mount-points=true
and excluding drive ``d:\``, mount point ``c:\mnt`` and volume ``\\?\Volume{04ce0545-3391-11e0-ba2f-806e6f6e6963}\`` as
.. code-block:: console
-o vss.exclude-volumes="d:;c:\mnt\;\\?\volume{04ce0545-3391-11e0-ba2f-806e6f6e6963}"
VSS provider can be specified by GUID
.. code-block:: console
-o vss.provider={3f900f90-00e9-440e-873a-96ca5eb079e5}
or by name
.. code-block:: console
-o vss.provider="Hyper-V IC Software Shadow Copy Provider"
Also ``MS`` can be used as alias for ``Microsoft Software Shadow Copy provider 1.0``.
By default VSS ignores Outlook OST files. This is not a restriction of restic By default VSS ignores Outlook OST files. This is not a restriction of restic
but the default Windows VSS configuration. The files not to snapshot are but the default Windows VSS configuration. The files not to snapshot are
configured in the Windows registry under the following key: configured in the Windows registry under the following key:

View file

@ -3,41 +3,108 @@ package fs
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/options"
) )
// ErrorHandler is used to report errors via callback // VSSConfig holds extended options of windows volume shadow copy service.
type ErrorHandler func(item string, err error) error type VSSConfig struct {
ExcludeAllMountPoints bool `option:"exclude-all-mount-points" help:"exclude mountpoints from snapshotting on all volumes"`
ExcludeVolumes string `option:"exclude-volumes" help:"semicolon separated list of volumes to exclude from snapshotting (ex. 'c:\\;e:\\mnt;\\\\?\\Volume{...}')"`
Timeout time.Duration `option:"timeout" help:"time that the VSS can spend creating snapshot before timing out"`
Provider string `option:"provider" help:"VSS provider identifier which will be used for snapshotting"`
}
func init() {
if runtime.GOOS == "windows" {
options.Register("vss", VSSConfig{})
}
}
// NewVSSConfig returns a new VSSConfig with the default values filled in.
func NewVSSConfig() VSSConfig {
return VSSConfig{
Timeout: time.Second * 120,
}
}
// ParseVSSConfig parses a VSS extended options to VSSConfig struct.
func ParseVSSConfig(o options.Options) (VSSConfig, error) {
cfg := NewVSSConfig()
o = o.Extract("vss")
if err := o.Apply("vss", &cfg); err != nil {
return VSSConfig{}, err
}
return cfg, nil
}
// ErrorHandler is used to report errors via callback.
type ErrorHandler func(item string, err error)
// MessageHandler is used to report errors/messages via callbacks. // MessageHandler is used to report errors/messages via callbacks.
type MessageHandler func(msg string, args ...interface{}) type MessageHandler func(msg string, args ...interface{})
// VolumeFilter is used to filter volumes by it's mount point or GUID path.
type VolumeFilter func(volume string) bool
// LocalVss is a wrapper around the local file system which uses windows volume // LocalVss is a wrapper around the local file system which uses windows volume
// shadow copy service (VSS) in a transparent way. // shadow copy service (VSS) in a transparent way.
type LocalVss struct { type LocalVss struct {
FS FS
snapshots map[string]VssSnapshot snapshots map[string]VssSnapshot
failedSnapshots map[string]struct{} failedSnapshots map[string]struct{}
mutex sync.RWMutex mutex sync.RWMutex
msgError ErrorHandler msgError ErrorHandler
msgMessage MessageHandler msgMessage MessageHandler
excludeAllMountPoints bool
excludeVolumes map[string]struct{}
timeout time.Duration
provider string
} }
// statically ensure that LocalVss implements FS. // statically ensure that LocalVss implements FS.
var _ FS = &LocalVss{} var _ FS = &LocalVss{}
// parseMountPoints try to convert semicolon separated list of mount points
// to map of lowercased volume GUID pathes. Mountpoints already in volume
// GUID path format will be validated and normalized.
func parseMountPoints(list string, msgError ErrorHandler) (volumes map[string]struct{}) {
if list == "" {
return
}
for _, s := range strings.Split(list, ";") {
if v, err := GetVolumeNameForVolumeMountPoint(s); err != nil {
msgError(s, errors.Errorf("failed to parse vss.exclude-volumes [%s]: %s", s, err))
} else {
if volumes == nil {
volumes = make(map[string]struct{})
}
volumes[strings.ToLower(v)] = struct{}{}
}
}
return
}
// NewLocalVss creates a new wrapper around the windows filesystem using volume // NewLocalVss creates a new wrapper around the windows filesystem using volume
// shadow copy service to access locked files. // shadow copy service to access locked files.
func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler) *LocalVss { func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler, cfg VSSConfig) *LocalVss {
return &LocalVss{ return &LocalVss{
FS: Local{}, FS: Local{},
snapshots: make(map[string]VssSnapshot), snapshots: make(map[string]VssSnapshot),
failedSnapshots: make(map[string]struct{}), failedSnapshots: make(map[string]struct{}),
msgError: msgError, msgError: msgError,
msgMessage: msgMessage, msgMessage: msgMessage,
excludeAllMountPoints: cfg.ExcludeAllMountPoints,
excludeVolumes: parseMountPoints(cfg.ExcludeVolumes, msgError),
timeout: cfg.Timeout,
provider: cfg.Provider,
} }
} }
@ -50,7 +117,7 @@ func (fs *LocalVss) DeleteSnapshots() {
for volumeName, snapshot := range fs.snapshots { for volumeName, snapshot := range fs.snapshots {
if err := snapshot.Delete(); err != nil { if err := snapshot.Delete(); err != nil {
_ = fs.msgError(volumeName, errors.Errorf("failed to delete VSS snapshot: %s", err)) fs.msgError(volumeName, errors.Errorf("failed to delete VSS snapshot: %s", err))
activeSnapshots[volumeName] = snapshot activeSnapshots[volumeName] = snapshot
} }
} }
@ -78,12 +145,27 @@ func (fs *LocalVss) Lstat(name string) (os.FileInfo, error) {
return os.Lstat(fs.snapshotPath(name)) return os.Lstat(fs.snapshotPath(name))
} }
// isMountPointIncluded is true if given mountpoint included by user.
func (fs *LocalVss) isMountPointIncluded(mountPoint string) bool {
if fs.excludeVolumes == nil {
return true
}
volume, err := GetVolumeNameForVolumeMountPoint(mountPoint)
if err != nil {
fs.msgError(mountPoint, errors.Errorf("failed to get volume from mount point [%s]: %s", mountPoint, err))
return true
}
_, ok := fs.excludeVolumes[strings.ToLower(volume)]
return !ok
}
// snapshotPath returns the path inside a VSS snapshots if it already exists. // snapshotPath returns the path inside a VSS snapshots if it already exists.
// If the path is not yet available as a snapshot, a snapshot is created. // If the path is not yet available as a snapshot, a snapshot is created.
// If creation of a snapshot fails the file's original path is returned as // If creation of a snapshot fails the file's original path is returned as
// a fallback. // a fallback.
func (fs *LocalVss) snapshotPath(path string) string { func (fs *LocalVss) snapshotPath(path string) string {
fixPath := fixpath(path) fixPath := fixpath(path)
if strings.HasPrefix(fixPath, `\\?\UNC\`) { if strings.HasPrefix(fixPath, `\\?\UNC\`) {
@ -114,23 +196,36 @@ func (fs *LocalVss) snapshotPath(path string) string {
if !snapshotExists && !snapshotFailed { if !snapshotExists && !snapshotFailed {
vssVolume := volumeNameLower + string(filepath.Separator) vssVolume := volumeNameLower + string(filepath.Separator)
fs.msgMessage("creating VSS snapshot for [%s]\n", vssVolume)
if snapshot, err := NewVssSnapshot(vssVolume, 120, fs.msgError); err != nil { if !fs.isMountPointIncluded(vssVolume) {
_ = fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s", fs.msgMessage("snapshots for [%s] excluded by user\n", vssVolume)
vssVolume, err))
fs.failedSnapshots[volumeNameLower] = struct{}{} fs.failedSnapshots[volumeNameLower] = struct{}{}
} else { } else {
fs.snapshots[volumeNameLower] = snapshot fs.msgMessage("creating VSS snapshot for [%s]\n", vssVolume)
fs.msgMessage("successfully created snapshot for [%s]\n", vssVolume)
if len(snapshot.mountPointInfo) > 0 { var includeVolume VolumeFilter
fs.msgMessage("mountpoints in snapshot volume [%s]:\n", vssVolume) if !fs.excludeAllMountPoints {
for mp, mpInfo := range snapshot.mountPointInfo { includeVolume = func(volume string) bool {
info := "" return fs.isMountPointIncluded(volume)
if !mpInfo.IsSnapshotted() { }
info = " (not snapshotted)" }
if snapshot, err := NewVssSnapshot(fs.provider, vssVolume, fs.timeout, includeVolume, fs.msgError); err != nil {
fs.msgError(vssVolume, errors.Errorf("failed to create snapshot for [%s]: %s",
vssVolume, err))
fs.failedSnapshots[volumeNameLower] = struct{}{}
} else {
fs.snapshots[volumeNameLower] = snapshot
fs.msgMessage("successfully created snapshot for [%s]\n", vssVolume)
if len(snapshot.mountPointInfo) > 0 {
fs.msgMessage("mountpoints in snapshot volume [%s]:\n", vssVolume)
for mp, mpInfo := range snapshot.mountPointInfo {
info := ""
if !mpInfo.IsSnapshotted() {
info = " (not snapshotted)"
}
fs.msgMessage(" - %s%s\n", mp, info)
} }
fs.msgMessage(" - %s%s\n", mp, info)
} }
} }
} }
@ -173,9 +268,8 @@ func (fs *LocalVss) snapshotPath(path string) string {
snapshotPath = fs.Join(snapshot.GetSnapshotDeviceObject(), snapshotPath = fs.Join(snapshot.GetSnapshotDeviceObject(),
strings.TrimPrefix(fixPath, volumeName)) strings.TrimPrefix(fixPath, volumeName))
if snapshotPath == snapshot.GetSnapshotDeviceObject() { if snapshotPath == snapshot.GetSnapshotDeviceObject() {
snapshotPath = snapshotPath + string(filepath.Separator) snapshotPath += string(filepath.Separator)
} }
} else { } else {
// no snapshot is available for the requested path: // no snapshot is available for the requested path:
// -> try to backup without a snapshot // -> try to backup without a snapshot

View file

@ -0,0 +1,285 @@
// +build windows
package fs
import (
"fmt"
"regexp"
"strings"
"testing"
"time"
ole "github.com/go-ole/go-ole"
"github.com/restic/restic/internal/options"
)
func matchStrings(ptrs []string, strs []string) bool {
if len(ptrs) != len(strs) {
return false
}
for i, p := range ptrs {
if p == "" {
return false
}
matched, err := regexp.MatchString(p, strs[i])
if err != nil {
panic(err)
}
if !matched {
return false
}
}
return true
}
func matchMap(strs []string, m map[string]struct{}) bool {
if len(strs) != len(m) {
return false
}
for _, s := range strs {
if _, ok := m[s]; !ok {
return false
}
}
return true
}
func TestVSSConfig(t *testing.T) {
type config struct {
excludeAllMountPoints bool
timeout time.Duration
provider string
}
setTests := []struct {
input options.Options
output config
}{
{
options.Options{
"vss.timeout": "6h38m42s",
"vss.provider": "Ms",
},
config{
timeout: 23922000000000,
provider: "Ms",
},
},
{
options.Options{
"vss.exclude-all-mount-points": "t",
"vss.provider": "{b5946137-7b9f-4925-af80-51abd60b20d5}",
},
config{
excludeAllMountPoints: true,
timeout: 120000000000,
provider: "{b5946137-7b9f-4925-af80-51abd60b20d5}",
},
},
{
options.Options{
"vss.exclude-all-mount-points": "0",
"vss.exclude-volumes": "",
"vss.timeout": "120s",
"vss.provider": "Microsoft Software Shadow Copy provider 1.0",
},
config{
timeout: 120000000000,
provider: "Microsoft Software Shadow Copy provider 1.0",
},
},
}
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
cfg, err := ParseVSSConfig(test.input)
if err != nil {
t.Fatal(err)
}
errorHandler := func(item string, err error) {
t.Fatalf("unexpected error (%v)", err)
}
messageHandler := func(msg string, args ...interface{}) {
t.Fatalf("unexpected message (%s)", fmt.Sprintf(msg, args))
}
dst := NewLocalVss(errorHandler, messageHandler, cfg)
if dst.excludeAllMountPoints != test.output.excludeAllMountPoints ||
dst.excludeVolumes != nil || dst.timeout != test.output.timeout ||
dst.provider != test.output.provider {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v", test.output, dst)
}
})
}
}
func TestParseMountPoints(t *testing.T) {
volumeMatch := regexp.MustCompile(`^\\\\\?\\Volume\{[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}\}\\$`)
// It's not a good idea to test functions based on GetVolumeNameForVolumeMountPoint by calling
// GetVolumeNameForVolumeMountPoint itself, but we have restricted test environment:
// cannot manage volumes and can only be sure that the mount point C:\ exists
sysVolume, err := GetVolumeNameForVolumeMountPoint("C:")
if err != nil {
t.Fatal(err)
}
// We don't know a valid volume GUID path for c:\, but we'll at least check its format
if !volumeMatch.MatchString(sysVolume) {
t.Fatalf("invalid volume GUID path: %s", sysVolume)
}
// Changing the case and removing trailing backslash allows tests
// the equality of different ways of writing a volume name
sysVolumeMutated := strings.ToUpper(sysVolume[:len(sysVolume)-1])
sysVolumeMatch := strings.ToLower(sysVolume)
type check struct {
volume string
result bool
}
setTests := []struct {
input options.Options
output []string
checks []check
errors []string
}{
{
options.Options{
"vss.exclude-volumes": `c:;c:\;` + sysVolume + `;` + sysVolumeMutated,
},
[]string{
sysVolumeMatch,
},
[]check{
{`c:\`, false},
{`c:`, false},
{sysVolume, false},
{sysVolumeMutated, false},
},
[]string{},
},
{
options.Options{
"vss.exclude-volumes": `z:\nonexistent;c:;c:\windows\;\\?\Volume{39b9cac2-bcdb-4d51-97c8-0d0677d607fb}\`,
},
[]string{
sysVolumeMatch,
},
[]check{
{`c:\windows\`, true},
{`\\?\Volume{39b9cac2-bcdb-4d51-97c8-0d0677d607fb}\`, true},
{`c:`, false},
{``, true},
},
[]string{
`failed to parse vss\.exclude-volumes \[z:\\nonexistent\]:.*`,
`failed to parse vss\.exclude-volumes \[c:\\windows\\\]:.*`,
`failed to parse vss\.exclude-volumes \[\\\\\?\\Volume\{39b9cac2-bcdb-4d51-97c8-0d0677d607fb\}\\\]:.*`,
`failed to get volume from mount point \[c:\\windows\\\]:.*`,
`failed to get volume from mount point \[\\\\\?\\Volume\{39b9cac2-bcdb-4d51-97c8-0d0677d607fb\}\\\]:.*`,
`failed to get volume from mount point \[\]:.*`,
},
},
}
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
cfg, err := ParseVSSConfig(test.input)
if err != nil {
t.Fatal(err)
}
var log []string
errorHandler := func(item string, err error) {
log = append(log, strings.TrimSpace(err.Error()))
}
messageHandler := func(msg string, args ...interface{}) {
t.Fatalf("unexpected message (%s)", fmt.Sprintf(msg, args))
}
dst := NewLocalVss(errorHandler, messageHandler, cfg)
if !matchMap(test.output, dst.excludeVolumes) {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v",
test.output, dst.excludeVolumes)
}
for _, c := range test.checks {
if dst.isMountPointIncluded(c.volume) != c.result {
t.Fatalf(`wrong check: isMountPointIncluded("%s") != %v`, c.volume, c.result)
}
}
if !matchStrings(test.errors, log) {
t.Fatalf("wrong log, want:\n %#v\ngot:\n %#v", test.errors, log)
}
})
}
}
func TestParseProvider(t *testing.T) {
msProvider := ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}")
setTests := []struct {
provider string
id *ole.GUID
result string
}{
{
"",
ole.IID_NULL,
"",
},
{
"mS",
msProvider,
"",
},
{
"{B5946137-7b9f-4925-Af80-51abD60b20d5}",
msProvider,
"",
},
{
"Microsoft Software Shadow Copy provider 1.0",
msProvider,
"",
},
{
"{04560982-3d7d-4bbc-84f7-0712f833a28f}",
nil,
`invalid VSS provider "{04560982-3d7d-4bbc-84f7-0712f833a28f}"`,
},
{
"non-existent provider",
nil,
`invalid VSS provider "non-existent provider"`,
},
}
_ = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
for i, test := range setTests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
id, err := getProviderID(test.provider)
if err != nil && id != nil {
t.Fatalf("err!=nil but id=%v", id)
}
if test.result != "" || err != nil {
var result string
if err != nil {
result = err.Error()
}
if test.result != result || test.result == "" {
t.Fatalf("wrong result, want:\n %#v\ngot:\n %#v", test.result, result)
}
} else if !ole.IsEqualGUID(id, test.id) {
t.Fatalf("wrong id, want:\n %s\ngot:\n %s", test.id.String(), id.String())
}
})
}
}

View file

@ -4,6 +4,8 @@
package fs package fs
import ( import (
"time"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
) )
@ -31,10 +33,16 @@ func HasSufficientPrivilegesForVSS() error {
return errors.New("VSS snapshots are only supported on windows") return errors.New("VSS snapshots are only supported on windows")
} }
// GetVolumeNameForVolumeMountPoint add trailing backslash to input parameter
// and calls the equivalent windows api.
func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
return mountPoint, nil
}
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't // NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
// finish within the timeout an error is returned. // finish within the timeout an error is returned.
func NewVssSnapshot( func NewVssSnapshot(_ string,
_ string, _ uint, _ ErrorHandler) (VssSnapshot, error) { _ string, _ time.Duration, _ VolumeFilter, _ ErrorHandler) (VssSnapshot, error) {
return VssSnapshot{}, errors.New("VSS snapshots are only supported on windows") return VssSnapshot{}, errors.New("VSS snapshots are only supported on windows")
} }

View file

@ -5,10 +5,12 @@ package fs
import ( import (
"fmt" "fmt"
"math"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
ole "github.com/go-ole/go-ole" ole "github.com/go-ole/go-ole"
@ -20,8 +22,10 @@ import (
type HRESULT uint type HRESULT uint
// HRESULT constant values necessary for using VSS api. // HRESULT constant values necessary for using VSS api.
//nolint:golint
const ( const (
S_OK HRESULT = 0x00000000 S_OK HRESULT = 0x00000000
S_FALSE HRESULT = 0x00000001
E_ACCESSDENIED HRESULT = 0x80070005 E_ACCESSDENIED HRESULT = 0x80070005
E_OUTOFMEMORY HRESULT = 0x8007000E E_OUTOFMEMORY HRESULT = 0x8007000E
E_INVALIDARG HRESULT = 0x80070057 E_INVALIDARG HRESULT = 0x80070057
@ -255,6 +259,7 @@ type IVssBackupComponents struct {
} }
// IVssBackupComponentsVTable is the vtable for IVssBackupComponents. // IVssBackupComponentsVTable is the vtable for IVssBackupComponents.
// nolint:structcheck
type IVssBackupComponentsVTable struct { type IVssBackupComponentsVTable struct {
ole.IUnknownVtbl ole.IUnknownVtbl
getWriterComponentsCount uintptr getWriterComponentsCount uintptr
@ -364,7 +369,7 @@ func (vss *IVssBackupComponents) convertToVSSAsync(
} }
// IsVolumeSupported calls the equivalent VSS api. // IsVolumeSupported calls the equivalent VSS api.
func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, error) { func (vss *IVssBackupComponents) IsVolumeSupported(providerID *ole.GUID, volumeName string) (bool, error) {
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName) volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
if err != nil { if err != nil {
panic(err) panic(err)
@ -374,7 +379,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
var result uintptr var result uintptr
if runtime.GOARCH == "386" { if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL)) id := (*[4]uintptr)(unsafe.Pointer(providerID))
result, _, _ = syscall.Syscall9(vss.getVTable().isVolumeSupported, 7, result, _, _ = syscall.Syscall9(vss.getVTable().isVolumeSupported, 7,
uintptr(unsafe.Pointer(vss)), id[0], id[1], id[2], id[3], uintptr(unsafe.Pointer(vss)), id[0], id[1], id[2], id[3],
@ -382,7 +387,7 @@ func (vss *IVssBackupComponents) IsVolumeSupported(volumeName string) (bool, err
0) 0)
} else { } else {
result, _, _ = syscall.Syscall6(vss.getVTable().isVolumeSupported, 4, result, _, _ = syscall.Syscall6(vss.getVTable().isVolumeSupported, 4,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(ole.IID_NULL)), uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(providerID)),
uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(&isSupportedRaw)), 0, uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(&isSupportedRaw)), 0,
0) 0)
} }
@ -408,24 +413,24 @@ func (vss *IVssBackupComponents) StartSnapshotSet() (ole.GUID, error) {
} }
// AddToSnapshotSet calls the equivalent VSS api. // AddToSnapshotSet calls the equivalent VSS api.
func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, idSnapshot *ole.GUID) error { func (vss *IVssBackupComponents) AddToSnapshotSet(volumeName string, providerID *ole.GUID, idSnapshot *ole.GUID) error {
volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName) volumeNamePointer, err := syscall.UTF16PtrFromString(volumeName)
if err != nil { if err != nil {
panic(err) panic(err)
} }
var result uintptr = 0 var result uintptr
if runtime.GOARCH == "386" { if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(ole.IID_NULL)) id := (*[4]uintptr)(unsafe.Pointer(providerID))
result, _, _ = syscall.Syscall9(vss.getVTable().addToSnapshotSet, 7, result, _, _ = syscall.Syscall9(vss.getVTable().addToSnapshotSet, 7,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), id[0], id[1], uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0) id[0], id[1], id[2], id[3], uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
} else { } else {
result, _, _ = syscall.Syscall6(vss.getVTable().addToSnapshotSet, 4, result, _, _ = syscall.Syscall6(vss.getVTable().addToSnapshotSet, 4,
uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)), uintptr(unsafe.Pointer(vss)), uintptr(unsafe.Pointer(volumeNamePointer)),
uintptr(unsafe.Pointer(ole.IID_NULL)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0) uintptr(unsafe.Pointer(providerID)), uintptr(unsafe.Pointer(idSnapshot)), 0, 0)
} }
return newVssErrorIfResultNotOK("AddToSnapshotSet() failed", HRESULT(result)) return newVssErrorIfResultNotOK("AddToSnapshotSet() failed", HRESULT(result))
@ -478,9 +483,9 @@ func (vss *IVssBackupComponents) DoSnapshotSet() (*IVSSAsync, error) {
// DeleteSnapshots calls the equivalent VSS api. // DeleteSnapshots calls the equivalent VSS api.
func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ole.GUID, error) { func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ole.GUID, error) {
var deletedSnapshots int32 = 0 var deletedSnapshots int32
var nondeletedSnapshotID ole.GUID var nondeletedSnapshotID ole.GUID
var result uintptr = 0 var result uintptr
if runtime.GOARCH == "386" { if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(&snapshotID)) id := (*[4]uintptr)(unsafe.Pointer(&snapshotID))
@ -504,7 +509,7 @@ func (vss *IVssBackupComponents) DeleteSnapshots(snapshotID ole.GUID) (int32, ol
// GetSnapshotProperties calls the equivalent VSS api. // GetSnapshotProperties calls the equivalent VSS api.
func (vss *IVssBackupComponents) GetSnapshotProperties(snapshotID ole.GUID, func (vss *IVssBackupComponents) GetSnapshotProperties(snapshotID ole.GUID,
properties *VssSnapshotProperties) error { properties *VssSnapshotProperties) error {
var result uintptr = 0 var result uintptr
if runtime.GOARCH == "386" { if runtime.GOARCH == "386" {
id := (*[4]uintptr)(unsafe.Pointer(&snapshotID)) id := (*[4]uintptr)(unsafe.Pointer(&snapshotID))
@ -527,8 +532,8 @@ func vssFreeSnapshotProperties(properties *VssSnapshotProperties) error {
if err != nil { if err != nil {
return err return err
} }
// this function always succeeds and returns no value
proc.Call(uintptr(unsafe.Pointer(properties))) _, _, _ = proc.Call(uintptr(unsafe.Pointer(properties)))
return nil return nil
} }
@ -543,6 +548,7 @@ func (vss *IVssBackupComponents) BackupComplete() (*IVSSAsync, error) {
} }
// VssSnapshotProperties defines the properties of a VSS snapshot as part of the VSS api. // VssSnapshotProperties defines the properties of a VSS snapshot as part of the VSS api.
// nolint:structcheck
type VssSnapshotProperties struct { type VssSnapshotProperties struct {
snapshotID ole.GUID snapshotID ole.GUID
snapshotSetID ole.GUID snapshotSetID ole.GUID
@ -559,6 +565,24 @@ type VssSnapshotProperties struct {
status uint status uint
} }
// VssProviderProperties defines the properties of a VSS provider as part of the VSS api.
// nolint:structcheck
type VssProviderProperties struct {
providerID ole.GUID
providerName *uint16
providerType uint32
providerVersion *uint16
providerVersionID ole.GUID
classID ole.GUID
}
func vssFreeProviderProperties(p *VssProviderProperties) {
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerName)))
p.providerName = nil
ole.CoTaskMemFree(uintptr(unsafe.Pointer(p.providerVersion)))
p.providerVersion = nil
}
// GetSnapshotDeviceObject returns root path to access the snapshot files // GetSnapshotDeviceObject returns root path to access the snapshot files
// and folders. // and folders.
func (p *VssSnapshotProperties) GetSnapshotDeviceObject() string { func (p *VssSnapshotProperties) GetSnapshotDeviceObject() string {
@ -617,8 +641,13 @@ func (vssAsync *IVSSAsync) QueryStatus() (HRESULT, uint32) {
// WaitUntilAsyncFinished waits until either the async call is finished or // WaitUntilAsyncFinished waits until either the async call is finished or
// the given timeout is reached. // the given timeout is reached.
func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(millis uint32) error { func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(timeout time.Duration) error {
hresult := vssAsync.Wait(millis) const maxTimeout = math.MaxInt32 * time.Millisecond
if timeout > maxTimeout {
timeout = maxTimeout
}
hresult := vssAsync.Wait(uint32(timeout.Milliseconds()))
err := newVssErrorIfResultNotOK("Wait() failed", hresult) err := newVssErrorIfResultNotOK("Wait() failed", hresult)
if err != nil { if err != nil {
vssAsync.Cancel() vssAsync.Cancel()
@ -651,6 +680,75 @@ func (vssAsync *IVSSAsync) WaitUntilAsyncFinished(millis uint32) error {
return nil return nil
} }
// UIID_IVSS_ADMIN defines the GUID of IVSSAdmin.
var (
UIID_IVSS_ADMIN = ole.NewGUID("{77ED5996-2F63-11d3-8A39-00C04F72D8E3}")
CLSID_VSS_COORDINATOR = ole.NewGUID("{E579AB5F-1CC4-44b4-BED9-DE0991FF0623}")
)
// IVSSAdmin VSS api interface.
type IVSSAdmin struct {
ole.IUnknown
}
// IVSSAdminVTable is the vtable for IVSSAdmin.
// nolint:structcheck
type IVSSAdminVTable struct {
ole.IUnknownVtbl
registerProvider uintptr
unregisterProvider uintptr
queryProviders uintptr
abortAllSnapshotsInProgress uintptr
}
// getVTable returns the vtable for IVSSAdmin.
func (vssAdmin *IVSSAdmin) getVTable() *IVSSAdminVTable {
return (*IVSSAdminVTable)(unsafe.Pointer(vssAdmin.RawVTable))
}
// QueryProviders calls the equivalent VSS api.
func (vssAdmin *IVSSAdmin) QueryProviders() (*IVssEnumObject, error) {
var enum *IVssEnumObject
result, _, _ := syscall.Syscall(vssAdmin.getVTable().queryProviders, 2,
uintptr(unsafe.Pointer(vssAdmin)), uintptr(unsafe.Pointer(&enum)), 0)
return enum, newVssErrorIfResultNotOK("QueryProviders() failed", HRESULT(result))
}
// IVssEnumObject VSS api interface.
type IVssEnumObject struct {
ole.IUnknown
}
// IVssEnumObjectVTable is the vtable for IVssEnumObject.
// nolint:structcheck
type IVssEnumObjectVTable struct {
ole.IUnknownVtbl
next uintptr
skip uintptr
reset uintptr
clone uintptr
}
// getVTable returns the vtable for IVssEnumObject.
func (vssEnum *IVssEnumObject) getVTable() *IVssEnumObjectVTable {
return (*IVssEnumObjectVTable)(unsafe.Pointer(vssEnum.RawVTable))
}
// Next calls the equivalent VSS api.
func (vssEnum *IVssEnumObject) Next(count uint, props unsafe.Pointer) (uint, error) {
var fetched uint32
result, _, _ := syscall.Syscall6(vssEnum.getVTable().next, 4,
uintptr(unsafe.Pointer(vssEnum)), uintptr(count), uintptr(props),
uintptr(unsafe.Pointer(&fetched)), 0, 0)
if HRESULT(result) == S_FALSE {
return uint(fetched), nil
}
return uint(fetched), newVssErrorIfResultNotOK("Next() failed", HRESULT(result))
}
// MountPoint wraps all information of a snapshot of a mountpoint on a volume. // MountPoint wraps all information of a snapshot of a mountpoint on a volume.
type MountPoint struct { type MountPoint struct {
isSnapshotted bool isSnapshotted bool
@ -677,7 +775,7 @@ type VssSnapshot struct {
snapshotProperties VssSnapshotProperties snapshotProperties VssSnapshotProperties
snapshotDeviceObject string snapshotDeviceObject string
mountPointInfo map[string]MountPoint mountPointInfo map[string]MountPoint
timeoutInMillis uint32 timeout time.Duration
} }
// GetSnapshotDeviceObject returns root path to access the snapshot files // GetSnapshotDeviceObject returns root path to access the snapshot files
@ -694,7 +792,12 @@ func initializeVssCOMInterface() (*ole.IUnknown, error) {
} }
// ensure COM is initialized before use // ensure COM is initialized before use
ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED) if err = ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
// CoInitializeEx returns S_FALSE if COM is already initialized
if oleErr, ok := err.(*ole.OleError); !ok || HRESULT(oleErr.Code()) != S_FALSE {
return nil, err
}
}
var oleIUnknown *ole.IUnknown var oleIUnknown *ole.IUnknown
result, _, _ := vssInstance.Call(uintptr(unsafe.Pointer(&oleIUnknown))) result, _, _ := vssInstance.Call(uintptr(unsafe.Pointer(&oleIUnknown)))
@ -727,12 +830,34 @@ func HasSufficientPrivilegesForVSS() error {
return err return err
} }
// GetVolumeNameForVolumeMountPoint add trailing backslash to input parameter
// and calls the equivalent windows api.
func GetVolumeNameForVolumeMountPoint(mountPoint string) (string, error) {
if mountPoint != "" && mountPoint[len(mountPoint)-1] != filepath.Separator {
mountPoint += string(filepath.Separator)
}
mountPointPointer, err := syscall.UTF16PtrFromString(mountPoint)
if err != nil {
return mountPoint, err
}
// A reasonable size for the buffer to accommodate the largest possible
// volume GUID path is 50 characters.
volumeNameBuffer := make([]uint16, 50)
if err := windows.GetVolumeNameForVolumeMountPoint(
mountPointPointer, &volumeNameBuffer[0], 50); err != nil {
return mountPoint, err
}
return syscall.UTF16ToString(volumeNameBuffer), nil
}
// NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't // NewVssSnapshot creates a new vss snapshot. If creating the snapshots doesn't
// finish within the timeout an error is returned. // finish within the timeout an error is returned.
func NewVssSnapshot( func NewVssSnapshot(provider string,
volume string, timeoutInSeconds uint, msgError ErrorHandler) (VssSnapshot, error) { volume string, timeout time.Duration, filter VolumeFilter, msgError ErrorHandler) (VssSnapshot, error) {
is64Bit, err := isRunningOn64BitWindows() is64Bit, err := isRunningOn64BitWindows()
if err != nil { if err != nil {
return VssSnapshot{}, newVssTextError(fmt.Sprintf( return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"Failed to detect windows architecture: %s", err.Error())) "Failed to detect windows architecture: %s", err.Error()))
@ -744,7 +869,7 @@ func NewVssSnapshot(
runtime.GOARCH)) runtime.GOARCH))
} }
timeoutInMillis := uint32(timeoutInSeconds * 1000) deadline := time.Now().Add(timeout)
oleIUnknown, err := initializeVssCOMInterface() oleIUnknown, err := initializeVssCOMInterface()
if oleIUnknown != nil { if oleIUnknown != nil {
@ -778,6 +903,12 @@ func NewVssSnapshot(
iVssBackupComponents := (*IVssBackupComponents)(unsafe.Pointer(comInterface)) iVssBackupComponents := (*IVssBackupComponents)(unsafe.Pointer(comInterface))
providerID, err := getProviderID(provider)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
if err := iVssBackupComponents.InitializeForBackup(); err != nil { if err := iVssBackupComponents.InitializeForBackup(); err != nil {
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err return VssSnapshot{}, err
@ -796,13 +927,13 @@ func NewVssSnapshot(
} }
err = callAsyncFunctionAndWait(iVssBackupComponents.GatherWriterMetadata, err = callAsyncFunctionAndWait(iVssBackupComponents.GatherWriterMetadata,
"GatherWriterMetadata", timeoutInMillis) "GatherWriterMetadata", deadline)
if err != nil { if err != nil {
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err return VssSnapshot{}, err
} }
if isSupported, err := iVssBackupComponents.IsVolumeSupported(volume); err != nil { if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, volume); err != nil {
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err return VssSnapshot{}, err
} else if !isSupported { } else if !isSupported {
@ -817,44 +948,53 @@ func NewVssSnapshot(
return VssSnapshot{}, err return VssSnapshot{}, err
} }
if err := iVssBackupComponents.AddToSnapshotSet(volume, &snapshotSetID); err != nil { if err := iVssBackupComponents.AddToSnapshotSet(volume, providerID, &snapshotSetID); err != nil {
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err return VssSnapshot{}, err
} }
mountPoints, err := enumerateMountedFolders(volume)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"failed to enumerate mount points for volume %s: %s", volume, err))
}
mountPointInfo := make(map[string]MountPoint) mountPointInfo := make(map[string]MountPoint)
for _, mountPoint := range mountPoints { // if filter==nil just don't process mount points for this volume at all
// ensure every mountpoint is available even without a valid if filter != nil {
// snapshot because we need to consider this when backing up files mountPoints, err := enumerateMountedFolders(volume)
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: false}
if isSupported, err := iVssBackupComponents.IsVolumeSupported(mountPoint); err != nil {
continue
} else if !isSupported {
continue
}
var mountPointSnapshotSetID ole.GUID
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, &mountPointSnapshotSetID)
if err != nil { if err != nil {
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err
return VssSnapshot{}, newVssTextError(fmt.Sprintf(
"failed to enumerate mount points for volume %s: %s", volume, err))
} }
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: true, for _, mountPoint := range mountPoints {
snapshotSetID: mountPointSnapshotSetID} // ensure every mountpoint is available even without a valid
// snapshot because we need to consider this when backing up files
mountPointInfo[mountPoint] = MountPoint{isSnapshotted: false}
if !filter(mountPoint) {
continue
} else if isSupported, err := iVssBackupComponents.IsVolumeSupported(providerID, mountPoint); err != nil {
continue
} else if !isSupported {
continue
}
var mountPointSnapshotSetID ole.GUID
err := iVssBackupComponents.AddToSnapshotSet(mountPoint, providerID, &mountPointSnapshotSetID)
if err != nil {
iVssBackupComponents.Release()
return VssSnapshot{}, err
}
mountPointInfo[mountPoint] = MountPoint{
isSnapshotted: true,
snapshotSetID: mountPointSnapshotSetID,
}
}
} }
err = callAsyncFunctionAndWait(iVssBackupComponents.PrepareForBackup, "PrepareForBackup", err = callAsyncFunctionAndWait(iVssBackupComponents.PrepareForBackup, "PrepareForBackup",
timeoutInMillis) deadline)
if err != nil { if err != nil {
// After calling PrepareForBackup one needs to call AbortBackup() before releasing the VSS // After calling PrepareForBackup one needs to call AbortBackup() before releasing the VSS
// instance for proper cleanup. // instance for proper cleanup.
@ -865,9 +1005,9 @@ func NewVssSnapshot(
} }
err = callAsyncFunctionAndWait(iVssBackupComponents.DoSnapshotSet, "DoSnapshotSet", err = callAsyncFunctionAndWait(iVssBackupComponents.DoSnapshotSet, "DoSnapshotSet",
timeoutInMillis) deadline)
if err != nil { if err != nil {
iVssBackupComponents.AbortBackup() _ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err return VssSnapshot{}, err
} }
@ -875,13 +1015,12 @@ func NewVssSnapshot(
var snapshotProperties VssSnapshotProperties var snapshotProperties VssSnapshotProperties
err = iVssBackupComponents.GetSnapshotProperties(snapshotSetID, &snapshotProperties) err = iVssBackupComponents.GetSnapshotProperties(snapshotSetID, &snapshotProperties)
if err != nil { if err != nil {
iVssBackupComponents.AbortBackup() _ = iVssBackupComponents.AbortBackup()
iVssBackupComponents.Release() iVssBackupComponents.Release()
return VssSnapshot{}, err return VssSnapshot{}, err
} }
for mountPoint, info := range mountPointInfo { for mountPoint, info := range mountPointInfo {
if !info.isSnapshotted { if !info.isSnapshotted {
continue continue
} }
@ -900,8 +1039,10 @@ func NewVssSnapshot(
mountPointInfo[mountPoint] = info mountPointInfo[mountPoint] = info
} }
return VssSnapshot{iVssBackupComponents, snapshotSetID, snapshotProperties, return VssSnapshot{
snapshotProperties.GetSnapshotDeviceObject(), mountPointInfo, timeoutInMillis}, nil iVssBackupComponents, snapshotSetID, snapshotProperties,
snapshotProperties.GetSnapshotDeviceObject(), mountPointInfo, time.Until(deadline),
}, nil
} }
// Delete deletes the created snapshot. // Delete deletes the created snapshot.
@ -922,15 +1063,17 @@ func (p *VssSnapshot) Delete() error {
if p.iVssBackupComponents != nil { if p.iVssBackupComponents != nil {
defer p.iVssBackupComponents.Release() defer p.iVssBackupComponents.Release()
deadline := time.Now().Add(p.timeout)
err = callAsyncFunctionAndWait(p.iVssBackupComponents.BackupComplete, "BackupComplete", err = callAsyncFunctionAndWait(p.iVssBackupComponents.BackupComplete, "BackupComplete",
p.timeoutInMillis) deadline)
if err != nil { if err != nil {
return err return err
} }
if _, _, e := p.iVssBackupComponents.DeleteSnapshots(p.snapshotID); e != nil { if _, _, e := p.iVssBackupComponents.DeleteSnapshots(p.snapshotID); e != nil {
err = newVssTextError(fmt.Sprintf("Failed to delete snapshot: %s", e.Error())) err = newVssTextError(fmt.Sprintf("Failed to delete snapshot: %s", e.Error()))
p.iVssBackupComponents.AbortBackup() _ = p.iVssBackupComponents.AbortBackup()
if err != nil { if err != nil {
return err return err
} }
@ -940,12 +1083,61 @@ func (p *VssSnapshot) Delete() error {
return nil return nil
} }
func getProviderID(provider string) (*ole.GUID, error) {
providerLower := strings.ToLower(provider)
switch providerLower {
case "":
return ole.IID_NULL, nil
case "ms":
return ole.NewGUID("{b5946137-7b9f-4925-af80-51abd60b20d5}"), nil
}
comInterface, err := ole.CreateInstance(CLSID_VSS_COORDINATOR, UIID_IVSS_ADMIN)
if err != nil {
return nil, err
}
defer comInterface.Release()
vssAdmin := (*IVSSAdmin)(unsafe.Pointer(comInterface))
enum, err := vssAdmin.QueryProviders()
if err != nil {
return nil, err
}
defer enum.Release()
id := ole.NewGUID(provider)
var props struct {
objectType uint32
provider VssProviderProperties
}
for {
count, err := enum.Next(1, unsafe.Pointer(&props))
if err != nil {
return nil, err
}
if count < 1 {
return nil, errors.Errorf(`invalid VSS provider "%s"`, provider)
}
name := ole.UTF16PtrToString(props.provider.providerName)
vssFreeProviderProperties(&props.provider)
if id != nil && *id == props.provider.providerID ||
id == nil && providerLower == strings.ToLower(name) {
return &props.provider.providerID, nil
}
}
}
// asyncCallFunc is the callback type for callAsyncFunctionAndWait. // asyncCallFunc is the callback type for callAsyncFunctionAndWait.
type asyncCallFunc func() (*IVSSAsync, error) type asyncCallFunc func() (*IVSSAsync, error)
// callAsyncFunctionAndWait calls an async functions and waits for it to either // callAsyncFunctionAndWait calls an async functions and waits for it to either
// finish or timeout. // finish or timeout.
func callAsyncFunctionAndWait(function asyncCallFunc, name string, timeoutInMillis uint32) error { func callAsyncFunctionAndWait(function asyncCallFunc, name string, deadline time.Time) error {
iVssAsync, err := function() iVssAsync, err := function()
if err != nil { if err != nil {
return err return err
@ -955,7 +1147,12 @@ func callAsyncFunctionAndWait(function asyncCallFunc, name string, timeoutInMill
return newVssTextError(fmt.Sprintf("%s() returned nil", name)) return newVssTextError(fmt.Sprintf("%s() returned nil", name))
} }
err = iVssAsync.WaitUntilAsyncFinished(timeoutInMillis) timeout := time.Until(deadline)
if timeout <= 0 {
return newVssTextError(fmt.Sprintf("%s() deadline exceeded", name))
}
err = iVssAsync.WaitUntilAsyncFinished(timeout)
iVssAsync.Release() iVssAsync.Release()
return err return err
} }
@ -1036,6 +1233,7 @@ func enumerateMountedFolders(volume string) ([]string, error) {
return mountedFolders, nil return mountedFolders, nil
} }
// nolint:errcheck
defer windows.FindVolumeMountPointClose(handle) defer windows.FindVolumeMountPointClose(handle)
volumeMountPoint := syscall.UTF16ToString(volumeMountPointBuffer) volumeMountPoint := syscall.UTF16ToString(volumeMountPointBuffer)