Fix review comments

Change lowerPrivileges from bool to atomic.Bool.
Add missing cleanup from upstream go-winio.
Add handling for ERROR_NOT_ALL_ASSIGNED warning.
This commit is contained in:
aneesh-n 2024-05-06 16:54:08 -06:00
parent 672f6cd776
commit a4fd1b91e5
No known key found for this signature in database
GPG key ID: 6F5A52831C046F44

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"unicode/utf16" "unicode/utf16"
"unsafe" "unsafe"
@ -26,7 +27,7 @@ var (
// SeTakeOwnershipPrivilege allows the application to take ownership of files and directories, regardless of the permissions set on them. // SeTakeOwnershipPrivilege allows the application to take ownership of files and directories, regardless of the permissions set on them.
SeTakeOwnershipPrivilege = "SeTakeOwnershipPrivilege" SeTakeOwnershipPrivilege = "SeTakeOwnershipPrivilege"
lowerPrivileges bool lowerPrivileges atomic.Bool
) )
// Flags for backup and restore with admin permissions // Flags for backup and restore with admin permissions
@ -46,14 +47,15 @@ func GetSecurityDescriptor(filePath string) (securityDescriptor *[]byte, err err
var sd *windows.SECURITY_DESCRIPTOR var sd *windows.SECURITY_DESCRIPTOR
if lowerPrivileges { if lowerPrivileges.Load() {
sd, err = getNamedSecurityInfoLow(filePath) sd, err = getNamedSecurityInfoLow(filePath)
} else { } else {
sd, err = getNamedSecurityInfoHigh(filePath) sd, err = getNamedSecurityInfoHigh(filePath)
} }
if err != nil { if err != nil {
if !lowerPrivileges && isHandlePrivilegeNotHeldError(err) { if !lowerPrivileges.Load() && isHandlePrivilegeNotHeldError(err) {
lowerPrivileges = true // If ERROR_PRIVILEGE_NOT_HELD is encountered, fallback to backups/restores using lower non-admin privileges.
lowerPrivileges.Store(true)
sd, err = getNamedSecurityInfoLow(filePath) sd, err = getNamedSecurityInfoLow(filePath)
if err != nil { if err != nil {
return nil, fmt.Errorf("get low-level named security info failed with: %w", err) return nil, fmt.Errorf("get low-level named security info failed with: %w", err)
@ -104,16 +106,16 @@ func SetSecurityDescriptor(filePath string, securityDescriptor *[]byte) error {
sacl = nil sacl = nil
} }
if lowerPrivileges { if lowerPrivileges.Load() {
err = setNamedSecurityInfoLow(filePath, dacl) err = setNamedSecurityInfoLow(filePath, dacl)
} else { } else {
err = setNamedSecurityInfoHigh(filePath, owner, group, dacl, sacl) err = setNamedSecurityInfoHigh(filePath, owner, group, dacl, sacl)
} }
if err != nil { if err != nil {
if isHandlePrivilegeNotHeldError(err) { if !lowerPrivileges.Load() && isHandlePrivilegeNotHeldError(err) {
// If ERROR_PRIVILEGE_NOT_HELD is encountered, fallback to backups/restores using lower non-admin privileges. // If ERROR_PRIVILEGE_NOT_HELD is encountered, fallback to backups/restores using lower non-admin privileges.
lowerPrivileges = true lowerPrivileges.Store(true)
err = setNamedSecurityInfoLow(filePath, dacl) err = setNamedSecurityInfoLow(filePath, dacl)
if err != nil { if err != nil {
return fmt.Errorf("set low-level named security info failed with: %w", err) return fmt.Errorf("set low-level named security info failed with: %w", err)
@ -231,7 +233,7 @@ const (
SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
//revive:disable-next-line:var-naming ALL_CAPS //revive:disable-next-line:var-naming ALL_CAPS
ERROR_NOT_ALL_ASSIGNED syscall.Errno = windows.ERROR_NOT_ALL_ASSIGNED ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED
) )
var ( var (
@ -287,11 +289,6 @@ func enableProcessPrivileges(names []string) error {
return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED) return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
} }
// DisableProcessPrivileges disables privileges globally for the process.
func DisableProcessPrivileges(names []string) error {
return enableDisableProcessPrivilege(names, 0)
}
func enableDisableProcessPrivilege(names []string, action uint32) error { func enableDisableProcessPrivilege(names []string, action uint32) error {
privileges, err := mapPrivileges(names) privileges, err := mapPrivileges(names)
if err != nil { if err != nil {
@ -325,7 +322,7 @@ func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) e
return err return err
} }
if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
return &PrivilegeError{privileges} debug.Log("Not all requested privileges were fully set: %v. AdjustTokenPrivileges returned warning: %v", privileges, err)
} }
return nil return nil
} }
@ -349,6 +346,15 @@ func getPrivilegeName(luid uint64) string {
return string(utf16.Decode(displayNameBuffer[:displayBufSize])) return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
} }
// The functions below are copied over from https://github.com/microsoft/go-winio/blob/main/zsyscall_windows.go
// This windows api always returns an error even in case of success, warnings (partial success) and error cases.
//
// Full success - When we call this with admin permissions, it returns DNS_ERROR_RCODE_NO_ERROR (0).
// This gets translated to errErrorEinval and ultimately in adjustTokenPrivileges, it gets ignored.
//
// Partial success - If we call this api without admin privileges, privileges related to SACLs do not get set and
// though the api returns success, it returns an error - golang.org/x/sys/windows.ERROR_NOT_ALL_ASSIGNED (1300)
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) { func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
var _p0 uint32 var _p0 uint32
if releaseAll { if releaseAll {
@ -356,7 +362,7 @@ func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, ou
} }
r0, _, e1 := syscall.SyscallN(procAdjustTokenPrivileges.Addr(), uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize))) r0, _, e1 := syscall.SyscallN(procAdjustTokenPrivileges.Addr(), uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
success = r0 != 0 success = r0 != 0
if !success { if true {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
@ -372,7 +378,7 @@ func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16,
} }
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageID *uint32) (err error) { func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageID *uint32) (err error) {
r1, _, e1 := syscall.SyscallN(procLookupPrivilegeDisplayNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageID)), 0) r1, _, e1 := syscall.SyscallN(procLookupPrivilegeDisplayNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageID)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@ -389,7 +395,7 @@ func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *
} }
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) { func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
r1, _, e1 := syscall.SyscallN(procLookupPrivilegeNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), 0, 0) r1, _, e1 := syscall.SyscallN(procLookupPrivilegeNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@ -418,6 +424,8 @@ func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err
return return
} }
// The code below was copied from https://github.com/microsoft/go-winio/blob/main/tools/mkwinsyscall/mkwinsyscall.go
// errnoErr returns common boxed Errno values, to prevent // errnoErr returns common boxed Errno values, to prevent
// allocations at runtime. // allocations at runtime.
func errnoErr(e syscall.Errno) error { func errnoErr(e syscall.Errno) error {