import hashlib
import logging
import xml.etree.ElementTree as ET

import httpx
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials

from frostfs_testlib import reporter
from frostfs_testlib.clients import HttpClient
from frostfs_testlib.utils.file_utils import TestFile

logger = logging.getLogger("NeoLogger")

# Default timeout forwarded to HttpClient.send (presumably seconds -- confirm in HttpClient).
DEFAULT_TIMEOUT = 60.0


class S3HttpClient:
    """S3 client that sends raw HTTP requests signed with AWS Signature V4.

    Requests are assembled as botocore ``AWSRequest`` objects, signed with
    ``SigV4Auth`` and sent through frostfs-testlib's ``HttpClient``.

    A single signer (``self.signature``) is shared by every request. Its service
    name follows the most recent ``set_endpoint`` ("s3") or ``set_iam_endpoint``
    ("iam") call.
    """

    def __init__(
        self,
        s3gate_endpoint: str,
        access_key_id: str,
        secret_access_key: str,
        profile: str = "default",
        region: str = "us-east-1",
    ) -> None:
        """Create the client and configure it to sign requests for the S3 endpoint.

        Args:
            s3gate_endpoint: Base URL of the S3 gateway; object URLs are built as
                ``{s3gate_endpoint}/{bucket}/{key}``.
            access_key_id: Access key used for SigV4 signing.
            secret_access_key: Secret key used for SigV4 signing.
            profile: Stored on the instance; not used by the code in this class.
            region: Region name passed to the SigV4 signer.
        """
        self.http_client = HttpClient()
        self.credentials = Credentials(access_key_id, secret_access_key)
        self.profile = profile
        self.region = region

        self.iam_endpoint: str | None = None
        self.s3gate_endpoint: str | None = None
        self.service: str | None = None
        self.signature: SigV4Auth | None = None
        self.set_endpoint(s3gate_endpoint)

    def _to_s3_header(self, header: str) -> str:
        """Convert a Python parameter name into an S3 HTTP header name.

        Names starting with ``x_amz`` keep their case; all other names are
        title-cased. Afterwards ``"Acl"`` becomes ``"ACL"`` and underscores become
        hyphens. For example, ``if_match`` becomes ``If-Match`` and
        ``x_amz_expected_bucket_owner`` becomes ``x-amz-expected-bucket-owner``.
        """
        replacement_map = {
            "Acl": "ACL",
            "_": "-",
        }

        result = header if header.startswith("x_amz") else header.title()
        for find, replace in replacement_map.items():
            result = result.replace(find, replace)
        return result

    def _convert_to_s3_headers(self, scope: dict, exclude: list[str] | None = None) -> dict:
        """Build an S3 header dict from a name -> value mapping.

        Args:
            scope: Mapping of Python parameter names to values.
            exclude: Extra names to skip; ``"self"`` and ``"cls"`` are always skipped.

        Returns:
            A dict of header name -> value. Entries whose value is ``None`` are dropped.
        """
        excluded = {"self", "cls", *(exclude or [])}
        return {
            self._to_s3_header(name): value
            for name, value in scope.items()
            if name not in excluded and value is not None
        }

    def _create_aws_request(
        self,
        method: str,
        url: str,
        headers: dict,
        content: str | bytes | TestFile = None,
        params: dict = None,
    ) -> AWSRequest:
        """Build a SigV4-signed ``AWSRequest``.

        The body is taken from ``content``: a ``TestFile`` is read from disk, a
        ``str`` is UTF-8 encoded, ``bytes`` are used as-is, and ``None`` gives an
        empty body. The body's SHA-256 is sent as ``X-Amz-Content-SHA256``.

        Raises:
            TypeError: If ``content`` is of any other type.
        """
        data = b""
        if content is not None:
            if isinstance(content, TestFile):
                with open(content, "rb") as io_content:
                    data = io_content.read()
            elif isinstance(content, str):
                data = bytes(content, encoding="utf-8")
            elif isinstance(content, bytes):
                data = content
            else:
                raise TypeError(f"Content expected as a string, bytes or TestFile object, got: {content}")

        # Copy instead of mutating the caller's dict.
        headers = {**headers, "X-Amz-Content-SHA256": hashlib.sha256(data).hexdigest()}
        aws_request = AWSRequest(method, url, headers, data, params)
        self.signature.add_auth(aws_request)

        return aws_request

    @staticmethod
    def _find_xml_text(root: ET.Element, tag: str, body: str) -> str | None:
        """Return the text of the first descendant ``tag`` element of ``root``.

        The ``{*}`` wildcard matches the tag in any namespace or in none, so both
        plain and namespaced S3 response documents are handled.

        Raises:
            ValueError: If no such element exists; the message includes ``body``.
        """
        element = root.find(f".//{{*}}{tag}")
        if element is None:
            raise ValueError(f"Element <{tag}> not found in S3 response: {body}")
        return element.text

    def _exec_request(
        self,
        method: str,
        url: str,
        headers: dict,
        content: str | bytes | TestFile = None,
        params: dict = None,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> dict:
        """Sign and send a request, then summarise the response.

        Returns:
            A dict with ``"LastModified"`` and ``"ETag"`` taken from the XML response
            body. ``"VersionId"`` is added when the ``x-amz-version-id`` response
            header is non-empty.

        Raises:
            httpx.HTTPStatusError: When ``raise_for_status`` rejects the response.
                The message is the response body, so S3 error details are visible.
            ValueError: When the response XML lacks ``LastModified`` or ``ETag``.
        """
        aws_request = self._create_aws_request(method, url, headers, content, params)
        response = self.http_client.send(
            aws_request.method,
            aws_request.url,
            headers=dict(aws_request.headers),
            data=aws_request.data,
            params=aws_request.params,
            timeout=timeout,
        )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise httpx.HTTPStatusError(response.text, request=response.request, response=response) from err

        body = response.read()
        root = ET.fromstring(body)
        data = {
            "LastModified": self._find_xml_text(root, "LastModified", response.text),
            "ETag": self._find_xml_text(root, "ETag", response.text),
        }

        version_id = response.headers.get("x-amz-version-id")
        if version_id:
            data["VersionId"] = version_id

        return data

    @reporter.step("Set endpoint S3 to {s3gate_endpoint}")
    def set_endpoint(self, s3gate_endpoint: str):
        """Point the client at an S3 gateway and sign subsequent requests for "s3"."""
        # The signer is shared with IAM. Skip only when both the endpoint and the
        # service already match; otherwise an earlier set_iam_endpoint() would leave
        # S3 requests signed for "iam".
        if self.s3gate_endpoint == s3gate_endpoint and self.service == "s3":
            return

        self.s3gate_endpoint = s3gate_endpoint
        self.service = "s3"
        self.signature = SigV4Auth(self.credentials, self.service, self.region)

    @reporter.step("Set endpoint IAM to {iam_endpoint}")
    def set_iam_endpoint(self, iam_endpoint: str):
        """Record the IAM endpoint and sign subsequent requests for "iam"."""
        # Mirror of set_endpoint: the service must match too before skipping.
        if self.iam_endpoint == iam_endpoint and self.service == "iam":
            return

        self.iam_endpoint = iam_endpoint
        self.service = "iam"
        self.signature = SigV4Auth(self.credentials, self.service, self.region)

    @reporter.step("Patch object S3")
    def patch_object(
        self,
        bucket: str,
        key: str,
        content: str | bytes | TestFile,
        content_range: str,
        version_id: str = None,
        if_match: str = None,
        if_unmodified_since: str = None,
        x_amz_expected_bucket_owner: str = None,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> dict:
        """Send a ``PATCH`` request for ``{s3gate_endpoint}/{bucket}/{key}``.

        Args:
            bucket: Bucket name (URL path segment).
            key: Object key (URL path segment).
            content: Patch payload; see ``_create_aws_request`` for accepted types.
            content_range: Sent as ``Content-Range``. A value not starting with
                ``"bytes"`` is wrapped as ``"bytes {content_range}/*"``.
            version_id: Sent as the ``VersionId`` query parameter when given.
            if_match: Sent as ``If-Match`` when given.
            if_unmodified_since: Sent as ``If-Unmodified-Since`` when given.
            x_amz_expected_bucket_owner: Sent as ``x-amz-expected-bucket-owner`` when given.
            timeout: Forwarded to ``HttpClient.send``.

        Returns:
            The dict produced by ``_exec_request``.
        """
        if content_range and not content_range.startswith("bytes"):
            content_range = f"bytes {content_range}/*"

        # List header arguments explicitly rather than scraping locals(), which also
        # picked up helper variables such as `url` and sent them as headers.
        headers = self._convert_to_s3_headers(
            {
                "content_range": content_range,
                "if_match": if_match,
                "if_unmodified_since": if_unmodified_since,
                "x_amz_expected_bucket_owner": x_amz_expected_bucket_owner,
            }
        )
        url = f"{self.s3gate_endpoint}/{bucket}/{key}"
        params = {"VersionId": version_id} if version_id is not None else None

        return self._exec_request("PATCH", url, headers, content, params, timeout=timeout)