Add openai client config arguments #426

Open · wants to merge 1 commit into base: main
76 changes: 76 additions & 0 deletions docs/model_calling/openai.md
@@ -509,6 +509,81 @@ from
) x
;
```
### OpenAI Client Configuration

Client configuration is supported through the `ai.openai_client_config()`
function. This function generates a configuration object that can be passed to
OpenAI-related functions to customize the behavior of the underlying client.

The `ai.openai_client_config` function accepts the following parameters:

- `base_url` (text, optional): The base URL for the OpenAI API.
- `timeout_seconds` (float8, optional): Request timeout in seconds.
  OpenAI-related functions also accept a `timeout_seconds` parameter that takes
  precedence over this value, as shown in the example below.
- `organization` (text, optional): OpenAI organization ID.
- `project` (text, optional): Project identifier for tracking purposes.
- `max_retries` (int, optional): Maximum number of retry attempts for failed
requests.
- `default_headers` (jsonb, optional): Default headers to include in all
requests.
- `default_query` (jsonb, optional): Default query parameters to include in all
requests.

All parameters are optional; any that are not provided fall back to the
defaults defined by the [OpenAI Python library][openai-python-lib].

For example, to set the organization and project for OpenAI requests:

```sql
SELECT ai.openai_embed(
model => 'text-embedding-ada-002',
input_text => 'Hello world',
client_config => ai.openai_client_config(
organization => 'org-0F65JvbWoebpWcrboA6vR2zP',
project => 'my-pgai-project'
)
);
```
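
The `default_headers` and `default_query` parameters take `jsonb` values. A
minimal sketch is shown below; the header name and query parameter are purely
illustrative, not something the OpenAI API requires:

```sql
SELECT ai.openai_embed(
    model => 'text-embedding-ada-002',
    input_text => 'Hello world',
    client_config => ai.openai_client_config(
        default_headers => '{"X-Request-Source": "pgai"}'::jsonb,
        default_query => '{"trace": "true"}'::jsonb
    )
);
```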

To set the timeout and maximum number of retries for OpenAI requests:

```sql
SELECT ai.openai_moderate(
input_text => 'Check this content',
client_config => ai.openai_client_config(
max_retries => 3,
timeout_seconds => 45.0
)
);
```
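
When a `timeout_seconds` argument is passed directly to the calling function,
it takes precedence over the value in the client configuration, as described
above. A minimal sketch of that combination (the per-call value of 10 seconds
wins here):

```sql
SELECT ai.openai_embed(
    model => 'text-embedding-ada-002',
    input_text => 'Hello world',
    timeout_seconds => 10.0,
    client_config => ai.openai_client_config(
        timeout_seconds => 45.0
    )
);
```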

#### Migrating the base URL from versions prior to 0.9.0

Version 0.9.0 introduced a new way to configure the base URL for OpenAI
requests. Previously, the base URL was passed as a separate parameter to the
OpenAI functions:

```sql
SELECT ai.openai_embed(
model => 'text-embedding-ada-002',
input_text => 'Hello world',
base_url => 'https://api.openai.com/v1'
);
```

Starting from version 0.9.0, the base URL is configured using the
`ai.openai_client_config` function:

```sql
SELECT ai.openai_embed(
model => 'text-embedding-ada-002',
input_text => 'Hello world',
client_config => ai.openai_client_config(
base_url => 'https://custom-openai-api.com/v1'
)
);
```

### Raw response

@@ -531,3 +606,4 @@ Python library and used in the same way as described in the [OpenAI Python
library documentation for Undocumented request params][undocumented-params].

[undocumented-params]: https://openai.com/docs/api-reference/python#undocumented-request-params
[openai-python-lib]: https://github.com/openai/openai-python
35 changes: 19 additions & 16 deletions projects/extension/ai/openai.py
@@ -1,6 +1,7 @@
import json
from collections.abc import Generator
from datetime import datetime
from typing import Any

import openai

@@ -19,11 +20,13 @@ def get_openai_base_url(plpy) -> str | None:
def make_client(
plpy,
api_key: str,
base_url: str | None = None,
client_config: dict[str, Any] | None = None,
) -> openai.Client:
if base_url is None:
base_url = get_openai_base_url(plpy)
return openai.Client(api_key=api_key, base_url=base_url)
if client_config is None:
client_config = {}
if "base_url" not in client_config:
client_config["base_url"] = get_openai_base_url(plpy)
return openai.Client(api_key=api_key, **client_config)


def str_arg_to_dict(arg: str | None) -> dict | None:
@@ -41,18 +44,18 @@ def create_kwargs(**kwargs) -> dict:
def list_models(
plpy,
api_key: str,
base_url: str | None = None,
client_config: dict[str, Any] | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
timeout: float | None = None,
timeout_seconds: float | None = None,
) -> Generator[tuple[str, datetime, str], None, None]:
client = make_client(plpy, api_key, base_url)
client = make_client(plpy, api_key, client_config)
from datetime import datetime, timezone

kwargs = create_kwargs(
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
timeout=timeout,
timeout=timeout_seconds,
)

for model in client.models.list(**kwargs):
@@ -62,26 +65,26 @@ def list_models(

def embed(
plpy,
client_config: dict[str, Any] | None,
model: str,
input: str | list[str] | list[int],
api_key: str,
base_url: str | None = None,
dimensions: int | None = None,
user: str | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
extra_body: str | None = None,
timeout: float | None = None,
timeout_seconds: float | None = None,
) -> Generator[tuple[int, list[float]], None, None]:
client = make_client(plpy, api_key, base_url)
client = make_client(plpy, api_key, client_config)

kwargs = create_kwargs(
dimensions=dimensions,
user=user,
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
extra_body=str_arg_to_dict(extra_body),
timeout=timeout,
timeout=timeout_seconds,
)
response = client.embeddings.create(input=input, model=model, **kwargs)
if not hasattr(response, "data"):
@@ -92,26 +95,26 @@ def embed(

def embed_with_raw_response(
plpy,
client_config: dict[str, Any] | None,
model: str,
input: str | list[str] | list[int],
api_key: str,
base_url: str | None = None,
dimensions: int | None = None,
user: str | None = None,
extra_headers: str | None = None,
extra_query: str | None = None,
extra_body: str | None = None,
timeout: float | None = None,
timeout_seconds: float | None = None,
) -> str:
client = make_client(plpy, api_key, base_url)
client = make_client(plpy, api_key, client_config)

kwargs = create_kwargs(
dimensions=dimensions,
user=user,
extra_headers=str_arg_to_dict(extra_headers),
extra_query=str_arg_to_dict(extra_query),
extra_body=str_arg_to_dict(extra_body),
timeout=timeout,
timeout=timeout_seconds,
)
response = client.embeddings.with_raw_response.create(
input=input, model=model, **kwargs