Skip to content

Commit

Permalink
feat!: consolidate OpenAI client configuration
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Replace individual client configuration parameters with
a single `client_config` parameter. The `base_url` parameter has been
moved into the new configuration structure.

- Add new `ai.openai_client_config()` function to generate client
  configurations
- Remove `base_url` parameter from all OpenAI functions
- Add `client_config` parameter to all OpenAI functions
- Support additional client options: timeout_seconds, organization,
  project, max_retries, default_headers, default_query

Migration:
Old:
  SELECT ai.openai_embed(
    model => 'model',
    input_text => 'text',
    base_url => 'url'
  );

New:
  SELECT ai.openai_embed(
    model => 'model',
    input_text => 'text',
    client_config => ai.openai_client_config(base_url => 'url')
  );
  • Loading branch information
alejandrodnm committed Feb 6, 2025
1 parent bf33286 commit be5b0ee
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 74 deletions.
76 changes: 76 additions & 0 deletions docs/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,81 @@ from
) x
;
```
### OpenAI Client Configuration
Client configuration is supported through the `ai.openai_client_config()`
function. This function generates a configuration object that can be used
across OpenAI-related functions to customize the client behavior.
The `ai.openai_client_config` function accepts the following parameters:
- `base_url` (text, optional): The base URL for the OpenAI API.
- `timeout_seconds` (float8, optional): Request timeout in seconds.
OpenAI-related functions also accept a `timeout_seconds` parameter, which takes
precedence over this value.
- `organization` (text, optional): OpenAI organization ID.
- `project` (text, optional): Project identifier for tracking purposes.
- `max_retries` (int, optional): Maximum number of retry attempts for failed
requests.
- `default_headers` (jsonb, optional): Default headers to include in all
requests.
- `default_query` (jsonb, optional): Default query parameters to include in all
requests.
All parameters are optional; any parameter that is not provided falls back to
the default defined by the [OpenAI Python library][openai-python-lib].
For example, to set the organization and project for OpenAI requests:
```sql
SELECT ai.openai_embed(
model => 'text-embedding-ada-002',
input_text => 'Hello world',
client_config => ai.openai_client_config(
organization => 'org-0F65JvbWoebpWcrboA6vR2zP',
project => 'my-pgai-project'
)
);
```
Setting the timeout and retries for OpenAI requests:
```sql
SELECT ai.openai_moderate(
input_text => 'Check this content',
client_config => ai.openai_client_config(
max_retries => 3,
timeout_seconds => 45.0
)
);
```
#### Base URL Migration from versions prior to 0.9.0
Version 0.9.0 introduced a new way to configure the base URL for OpenAI
requests. Previously, the base URL was passed as a separate parameter to the
OpenAI functions:
```sql
SELECT ai.openai_embed(
model => 'text-embedding-ada-002',
input_text => 'Hello world',
base_url => 'https://api.openai.com/v1'
);
```
Starting from version 0.9.0, the base URL is configured using the
`ai.openai_client_config` function:
```sql
SELECT ai.openai_embed(
model => 'text-embedding-ada-002',
input_text => 'Hello world',
client_config => ai.openai_client_config(
base_url => 'https://custom-openai-api.com/v1'
)
);
```
### Raw response
Expand All @@ -531,3 +606,4 @@ Python library and used in the same way as described in the [OpenAI Python
library documentation for Undocumented request params][undocumented-params].
[undocumented-params]: https://openai.com/docs/api-reference/python#undocumented-request-params
[openai-python-lib]: https://github.com/openai/openai-python
35 changes: 19 additions & 16 deletions projects/extension/ai/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from collections.abc import Generator
from datetime import datetime
from typing import Any

import openai

Expand All @@ -19,11 +20,13 @@ def get_openai_base_url(plpy) -> str | None:
def make_client(
plpy,
api_key: str,
base_url: str | None = None,
client_config: dict[str, Any] | None = None,
) -> openai.Client:
if base_url is None:
base_url = get_openai_base_url(plpy)
return openai.Client(api_key=api_key, base_url=base_url)
if client_config is None:
client_config = {}
if "base_url" not in client_config:
client_config["base_url"] = get_openai_base_url(plpy)
return openai.Client(api_key=api_key, **client_config)


def str_arg_to_dict(arg: str | None) -> dict | None:
Expand All @@ -41,18 +44,18 @@ def create_kwargs(**kwargs) -> dict:
def list_models(
plpy,
api_key: str,
base_url: str | None = None,
client_config: dict[str, Any] | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
timeout: float | None = None,
timeout_seconds: float | None = None,
) -> Generator[tuple[str, datetime, str], None, None]:
client = make_client(plpy, api_key, base_url)
client = make_client(plpy, api_key, client_config)
from datetime import datetime, timezone

kwargs = create_kwargs(
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
timeout=timeout,
timeout=timeout_seconds,
)

for model in client.models.list(**kwargs):
Expand All @@ -62,26 +65,26 @@ def list_models(

def embed(
plpy,
client_config: dict[str, Any] | None,
model: str,
input: str | list[str] | list[int],
api_key: str,
base_url: str | None = None,
dimensions: int | None = None,
user: str | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
extra_body: str | None = None,
timeout: float | None = None,
timeout_seconds: float | None = None,
) -> Generator[tuple[int, list[float]], None, None]:
client = make_client(plpy, api_key, base_url)
client = make_client(plpy, api_key, client_config)

kwargs = create_kwargs(
dimensions=dimensions,
user=user,
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
extra_body=str_arg_to_dict(extra_body),
timeout=timeout,
timeout=timeout_seconds,
)
response = client.embeddings.create(input=input, model=model, **kwargs)
if not hasattr(response, "data"):
Expand All @@ -92,26 +95,26 @@ def embed(

def embed_with_raw_response(
plpy,
client_config: dict[str, Any] | None,
model: str,
input: str | list[str] | list[int],
api_key: str,
base_url: str | None = None,
dimensions: int | None = None,
user: str | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
extra_body: str | None = None,
timeout: float | None = None,
timeout_seconds: float | None = None,
) -> str:
client = make_client(plpy, api_key, base_url)
client = make_client(plpy, api_key, client_config)

kwargs = create_kwargs(
dimensions=dimensions,
user=user,
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
extra_body=str_arg_to_dict(extra_body),
timeout=timeout,
timeout=timeout_seconds,
)
response = client.embeddings.with_raw_response.create(
input=input, model=model, **kwargs
Expand Down
Loading

0 comments on commit be5b0ee

Please sign in to comment.