From 56a0a0f59380ba56f020abd925e4731e7b87465f Mon Sep 17 00:00:00 2001 From: kiraksi Date: Mon, 11 Mar 2024 13:49:25 -0700 Subject: [PATCH] add async _call_api, RowIterator and get_job to implementation --- google/cloud/bigquery/async_client.py | 155 +++++++++++++++--- .../cloud/bigquery/opentelemetry_tracing.py | 33 +++- noxfile.py | 12 +- setup.py | 1 - tests/unit/test_async_client.py | 99 +++++++++++ 5 files changed, 272 insertions(+), 28 deletions(-) diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index 81bb9a197..3dc7632d5 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -1,6 +1,12 @@ from google.cloud.bigquery.client import * +from google.cloud.bigquery.client import ( + _add_server_timeout_header, + _extract_job_reference, +) +from google.cloud.bigquery.opentelemetry_tracing import async_create_span from google.cloud.bigquery import _job_helpers -from google.cloud.bigquery import table +from google.cloud.bigquery.table import * +from google.api_core.page_iterator import HTTPIterator from google.cloud.bigquery.retry import ( DEFAULT_ASYNC_JOB_RETRY, DEFAULT_ASYNC_RETRY, @@ -8,12 +14,54 @@ ) from google.api_core import retry_async as retries import asyncio +from google.auth.transport import _aiohttp_requests + +# This code is experimental class AsyncClient: def __init__(self, *args, **kwargs): self._client = Client(*args, **kwargs) + async def get_job( + self, + job_id: Union[str, job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob], + project: Optional[str] = None, + location: Optional[str] = None, + retry: retries.AsyncRetry = DEFAULT_ASYNC_RETRY, + timeout: TimeoutType = DEFAULT_TIMEOUT, + ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob, job.UnknownJob]: + extra_params = {"projection": "full"} + + project, location, job_id = _extract_job_reference( + job_id, project=project, location=location + ) + + if project is None: + project = self._client.project + + if location is None: + location = self._client.location + + if location is not None: + extra_params["location"] = location + + path = "/projects/{}/jobs/{}".format(project, job_id) + + span_attributes = {"path": path, "job_id": job_id, "location": location} + + resource = await self._call_api( + retry, + span_name="BigQuery.getJob", + span_attributes=span_attributes, + method="GET", + path=path, + query_params=extra_params, + timeout=timeout, + ) + + return await asyncio.to_thread(self._client.job_from_resource(await resource)) + async def query_and_wait( self, query, @@ -46,7 +94,7 @@ async def query_and_wait( ) return await async_query_and_wait( - self._client, + self, query, job_config=job_config, location=location, @@ -59,9 +107,41 @@ async def query_and_wait( max_results=max_results, ) + async def _call_api( + self, + retry: Optional[retries.AsyncRetry] = None, + span_name: Optional[str] = None, + span_attributes: Optional[Dict] = None, + job_ref=None, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ): + kwargs = _add_server_timeout_header(headers, kwargs) + + # Prepare the asynchronous request function + # async with _aiohttp_requests.Request(**kwargs) as response: + # response.raise_for_status() + # response = await response.json() # or response.text() + + async_call = functools.partial(self._client._connection.api_request, **kwargs) + + if retry: + async_call = retry(async_call) + + if span_name is not None: + async with async_create_span( + name=span_name, + attributes=span_attributes, + client=self._client, + job_ref=job_ref, + ): + return async_call() # Await the asynchronous call + + return async_call() # Await the asynchronous call + async def async_query_and_wait( - client: "Client", + client: "AsyncClient", query: str, *, job_config: Optional[job.QueryJobConfig], @@ -73,14 +153,12 @@ async def async_query_and_wait( job_retry: Optional[retries.AsyncRetry], page_size: Optional[int] = None, max_results: Optional[int] = None, -) -> table.RowIterator: - # Some API parameters aren't supported by the jobs.query API. In these - # cases, fallback to a jobs.insert call. +) -> RowIterator: if not _job_helpers._supported_by_jobs_query(job_config): return await async_wait_or_cancel( asyncio.to_thread( _job_helpers.query_jobs_insert( - client=client, + client=client._client, query=query, job_id=None, job_id_prefix=None, @@ -116,7 +194,7 @@ async def async_query_and_wait( span_attributes = {"path": path} if retry is not None: - response = client._call_api( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth), add back retry() + response = await client._call_api( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth), add back retry() retry=None, # We're calling the retry decorator ourselves, async_retries, need to implement after making HTTP calls async span_name="BigQuery.query", span_attributes=span_attributes, @@ -127,7 +205,7 @@ async def async_query_and_wait( ) else: - response = client._call_api( + response = await client._call_api( retry=None, span_name="BigQuery.query", span_attributes=span_attributes, @@ -149,17 +227,28 @@ async def async_query_and_wait( # client._list_rows_from_query_results directly. Need to update # RowIterator to fetch destination table via the job ID if needed. result = await async_wait_or_cancel( - _job_helpers._to_query_job(client, query, job_config, response), - api_timeout=api_timeout, - wait_timeout=wait_timeout, - retry=retry, - page_size=page_size, - max_results=max_results, + asyncio.to_thread( + _job_helpers._to_query_job(client._client, query, job_config, response), + api_timeout=api_timeout, + wait_timeout=wait_timeout, + retry=retry, + page_size=page_size, + max_results=max_results, + ) + ) + + def api_request(*args, **kwargs): + return client._call_api( + span_name="BigQuery.query", + span_attributes=span_attributes, + *args, + timeout=api_timeout, + **kwargs, ) - result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff - client=client, - api_request=functools.partial(client._call_api, retry, timeout=api_timeout), + result = AsyncRowIterator( # async of RowIterator? async version without all the pandas stuff + client=client._client, + api_request=api_request, path=None, schema=query_results.schema, max_results=max_results, @@ -186,10 +275,10 @@ async def async_wait_or_cancel( retry: Optional[retries.AsyncRetry], page_size: Optional[int], max_results: Optional[int], -) -> table.RowIterator: +) -> RowIterator: try: return asyncio.to_thread( - job.result( # run in a background thread + job.result( page_size=page_size, max_results=max_results, retry=retry, @@ -204,3 +293,29 @@ async def async_wait_or_cancel( # Don't eat the original exception if cancel fails. pass raise + + +class AsyncRowIterator(RowIterator): + async def _get_next_page_response(self): + """Asynchronous version of fetching the next response page.""" + if self._first_page_response: + rows = self._first_page_response.get(self._items_key, [])[ + : self.max_results + ] + response = { + self._items_key: rows, + } + if self._next_token in self._first_page_response: + response[self._next_token] = self._first_page_response[self._next_token] + + self._first_page_response = None + return response + + params = self._get_query_params() + if self._page_size is not None: + if self.page_number and "startIndex" in params: + del params["startIndex"] + params["maxResults"] = self._page_size + return await self.api_request( + method=self._HTTP_METHOD, path=self.path, query_params=params + ) diff --git a/google/cloud/bigquery/opentelemetry_tracing.py b/google/cloud/bigquery/opentelemetry_tracing.py index e2a05e4d0..c1594c1a2 100644 --- a/google/cloud/bigquery/opentelemetry_tracing.py +++ b/google/cloud/bigquery/opentelemetry_tracing.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager from google.api_core.exceptions import GoogleAPICallError # type: ignore logger = logging.getLogger(__name__) @@ -86,6 +86,37 @@ def create_span(name, attributes=None, client=None, job_ref=None): raise +@asynccontextmanager +async def async_create_span(name, attributes=None, client=None, job_ref=None): + """Asynchronous context manager for creating and exporting OpenTelemetry spans.""" + global _warned_telemetry + final_attributes = _get_final_span_attributes(attributes, client, job_ref) + + if not HAS_OPENTELEMETRY: + if not _warned_telemetry: + logger.debug( + "This service is instrumented using OpenTelemetry. " + "OpenTelemetry or one of its components could not be imported; " + "please add compatible versions of opentelemetry-api and " + "opentelemetry-instrumentation packages in order to get BigQuery " + "Tracing data." + ) + _warned_telemetry = True + yield None + return + tracer = trace.get_tracer(__name__) + + async with tracer.start_as_current_span( + name=name, attributes=final_attributes + ) as span: + try: + yield span + except GoogleAPICallError as error: + if error.code is not None: + span.set_status(Status(http_status_to_status_code(error.code))) + raise + + def _get_final_span_attributes(attributes=None, client=None, job_ref=None): """Compiles attributes from: client, job_ref, user-provided attributes. diff --git a/noxfile.py b/noxfile.py index c31d098b8..26d55111f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -80,8 +80,8 @@ def default(session, install_extras=True): constraints_path, ) - if install_extras and session.python in ["3.11", "3.12"]: - install_target = ".[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if install_extras and session.python in ["3.12"]: + install_target = ".[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" elif install_extras: install_target = ".[all]" else: @@ -188,8 +188,8 @@ def system(session): # Data Catalog needed for the column ACL test with a real Policy Tag. session.install("google-cloud-datacatalog", "-c", constraints_path) - if session.python in ["3.11", "3.12"]: - extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if session.python in ["3.12"]: + extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" # look at geopandas to see if it supports 3.11/3.12 (up to 3.11) else: extras = "[all]" session.install("-e", f".{extras}", "-c", constraints_path) @@ -254,8 +254,8 @@ def snippets(session): session.install("google-cloud-storage", "-c", constraints_path) session.install("grpcio", "-c", constraints_path) - if session.python in ["3.11", "3.12"]: - extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry]" + if session.python in ["3.12"]: + extras = "[bqstorage,ipywidgets,pandas,tqdm,opentelemetry,aiohttp]" else: extras = "[all]" session.install("-e", f".{extras}", "-c", constraints_path) diff --git a/setup.py b/setup.py index 9f6fabcfc..7d672d239 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,6 @@ # NOTE: Maintainers, please do not require google-cloud-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 - "google-auth >= 2.14.1, <3.0.0dev", "google-cloud-core >= 1.6.0, <3.0.0dev", "google-resumable-media >= 0.6.0, < 3.0dev", "packaging >= 20.0.0", diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index 472504711..e500c6340 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -297,6 +297,105 @@ def test_ctor_w_load_job_config(self): self.assertIsInstance(client._default_load_job_config, LoadJobConfig) self.assertTrue(client._default_load_job_config.create_session) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_miss_w_explict_project(self): + from google.cloud.exceptions import NotFound + + OTHER_PROJECT = "OTHER_PROJECT" + JOB_ID = "NONESUCH" + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client.get_job(JOB_ID, project=OTHER_PROJECT) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/OTHER_PROJECT/jobs/NONESUCH", + query_params={"projection": "full"}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_miss_w_client_location(self): + from google.cloud.exceptions import NotFound + + JOB_ID = "NONESUCH" + creds = _make_credentials() + client = self._make_one("client-proj", creds, location="client-loc") + conn = client._client._connection = make_connection() + + with self.assertRaises(NotFound): + await client.get_job(JOB_ID) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/client-proj/jobs/NONESUCH", + query_params={"projection": "full", "location": "client-loc"}, + timeout=DEFAULT_TIMEOUT, + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + @asyncio_run + async def test_get_job_hit_w_timeout(self): + from google.cloud.bigquery.job import CreateDisposition + from google.cloud.bigquery.job import QueryJob + from google.cloud.bigquery.job import WriteDisposition + + JOB_ID = "query_job" + QUERY_DESTINATION_TABLE = "query_destination_table" + QUERY = "SELECT * from test_dataset:test_table" + ASYNC_QUERY_DATA = { + "id": "{}:{}".format(self.PROJECT, JOB_ID), + "jobReference": { + "projectId": "resource-proj", + "jobId": "query_job", + "location": "us-east1", + }, + "state": "DONE", + "configuration": { + "query": { + "query": QUERY, + "destinationTable": { + "projectId": self.PROJECT, + "datasetId": self.DS_ID, + "tableId": QUERY_DESTINATION_TABLE, + }, + "createDisposition": CreateDisposition.CREATE_IF_NEEDED, + "writeDisposition": WriteDisposition.WRITE_TRUNCATE, + } + }, + } + creds = _make_credentials() + client = self._make_one(self.PROJECT, creds) + conn = client._client._connection = make_connection(ASYNC_QUERY_DATA) + job_from_resource = QueryJob.from_api_repr(ASYNC_QUERY_DATA, client._client) + + job = await client.get_job(job_from_resource, timeout=7.5) + + self.assertIsInstance(job, QueryJob) + self.assertEqual(job.job_id, JOB_ID) + self.assertEqual(job.project, "resource-proj") + self.assertEqual(job.location, "us-east1") + self.assertEqual(job.create_disposition, CreateDisposition.CREATE_IF_NEEDED) + self.assertEqual(job.write_disposition, WriteDisposition.WRITE_TRUNCATE) + + conn.api_request.assert_called_once_with( + method="GET", + path="/projects/resource-proj/jobs/query_job", + query_params={"projection": "full", "location": "us-east1"}, + timeout=7.5, + ) + @pytest.mark.skipif( sys.version_info < (3, 9), reason="requires python3.9 or higher" )