diff --git a/main.go b/main.go index e96d94b..570285e 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,9 @@ var threadCount = flag.Int("threads", 1024, "how many objects to operate on at a func main() { flag.Parse() + updateExpiry := time.Now().Add(time.Second * time.Duration(*updateExpiresWithin)) + lockExpiry := time.Now().Add(time.Second * time.Duration(*lockFor)) + options := []func(*config.LoadOptions) error{} if *region != "" { @@ -70,7 +73,7 @@ func main() { wg.Add(1) go func() { defer wg.Done() - queueWorker(svc, objectQueue) + queueWorker(svc, updateExpiry, lockExpiry, objectQueue) }() } @@ -89,17 +92,17 @@ func main() { wg.Wait() } -func queueWorker(svc *s3.Client, inQueue chan string) { +func queueWorker(svc *s3.Client, updateExpiry, lockExpiry time.Time, inQueue chan string) { for { object, more := <-inQueue if !more { return } - checkAndRenewObjectLock(svc, object) + checkAndRenewObjectLock(svc, updateExpiry, lockExpiry, object) } } -func checkAndRenewObjectLock(svc *s3.Client, object string) { +func checkAndRenewObjectLock(svc *s3.Client, updateExpiry, lockExpiry time.Time, object string) { updateHold := false if *updateExpiresWithin == 0 { updateHold = true @@ -108,9 +111,7 @@ func checkAndRenewObjectLock(svc *s3.Client, object string) { Bucket: bucket, Key: &object, }) - if retention == nil { - updateHold = true - } else if retention.Retention.RetainUntilDate.Before(time.Now().Add(time.Second * time.Duration(*updateExpiresWithin))) { + if retention == nil || retention.Retention.RetainUntilDate.Before(updateExpiry) { updateHold = true } } @@ -121,11 +122,13 @@ func checkAndRenewObjectLock(svc *s3.Client, object string) { Bucket: bucket, Key: &object, Retention: &types.ObjectLockRetention{ + // TODO: add flag for governance mode Mode: "COMPLIANCE", - RetainUntilDate: aws.Time(time.Now().Add(time.Second * time.Duration(*lockFor))), + RetainUntilDate: aws.Time(lockExpiry), }, }) if err != nil { + // TODO: handle 403 for when object already has a longer-lasting hold log.Fatalln("Failed to update retention for", object, err) } }