Skip to content

Commit

Permalink
feat: Updated to latest API
Browse files Browse the repository at this point in the history
- Updated the EF to the latest API
- Added a few more options
- Tests
  • Loading branch information
tazarov committed Apr 8, 2024
1 parent 0794be7 commit c0cbbed
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 59 deletions.
105 changes: 105 additions & 0 deletions chromadb/test/ef/test_voyageai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os

import pytest

from chromadb.utils.embedding_functions import VoyageAIEmbeddingFunction


def test_voyage() -> None:
    """Embed a single document with default settings and expect one non-empty vector."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key)
    vectors = embed_fn(["test doc"])

    # One input document must yield exactly one embedding with at least one dimension.
    assert vectors is not None
    assert len(vectors) == 1
    assert len(vectors[0]) > 0


def test_voyage_input_type_query() -> None:
    """Embed a single document with ``input_type="query"`` and expect one non-empty vector."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key, input_type="query")
    vectors = embed_fn(["test doc"])

    # One input document must yield exactly one embedding with at least one dimension.
    assert vectors is not None
    assert len(vectors) == 1
    assert len(vectors[0]) > 0


def test_voyage_input_type_document() -> None:
    """Embed a single document with ``input_type="document"`` and expect one non-empty vector."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key, input_type="document")
    vectors = embed_fn(["test doc"])

    # One input document must yield exactly one embedding with at least one dimension.
    assert vectors is not None
    assert len(vectors) == 1
    assert len(vectors[0]) > 0


def test_voyage_model() -> None:
    """Embed a code snippet with the non-default ``voyage-code-2`` model and expect one non-empty vector."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key, model_name="voyage-code-2")
    snippet = "def test():\n return 1"
    vectors = embed_fn([snippet])

    # One input document must yield exactly one embedding with at least one dimension.
    assert vectors is not None
    assert len(vectors) == 1
    assert len(vectors[0]) > 0


def test_voyage_truncation_default() -> None:
    """Embed a very long document without passing ``truncation`` and expect one non-empty vector."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key)
    # The same phrase repeated 10000 times forms the long input document.
    long_document = "this is a test-message" * 10000
    vectors = embed_fn([long_document])

    assert vectors is not None
    assert len(vectors) == 1
    assert len(vectors[0]) > 0


def test_voyage_truncation_enabled() -> None:
    """Embed a very long document with ``truncation=True`` and expect one non-empty vector."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key, truncation=True)
    # The same phrase repeated 10000 times forms the long input document.
    long_document = "this is a test-message" * 10000
    vectors = embed_fn([long_document])

    assert vectors is not None
    assert len(vectors) == 1
    assert len(vectors[0]) > 0


def test_voyage_truncation_disabled() -> None:
    """With ``truncation=False`` a very long document must raise a "too many tokens" error."""
    api_key = os.environ.get("VOYAGEAI_API_KEY")
    if api_key is None:
        pytest.skip("VOYAGEAI_API_KEY not set, not going to test VoyageAI EF.")

    embed_fn = VoyageAIEmbeddingFunction(api_key=api_key, truncation=False)
    # The same phrase repeated 10000 times forms the long input document.
    long_document = "this is a test-message" * 10000

    with pytest.raises(Exception, match="your batch has too many tokens"):
        embed_fn([long_document])


def test_voyage_no_api_key() -> None:
    """Constructing the embedding function without an API key must raise ``ValueError``.

    The constructor rejects a missing key before it imports ``voyageai`` or
    contacts the service, so this test runs without ``VOYAGEAI_API_KEY`` set
    and is therefore not skipped when the variable is absent.
    """
    with pytest.raises(ValueError, match="Please provide a VoyageAI API key"):
        VoyageAIEmbeddingFunction(api_key=None)  # type: ignore


def test_voyage_max_batch_size_exceeded_in_init() -> None:
    """A ``max_batch_size`` above the library limit must be rejected by the constructor.

    The check is local validation against ``voyageai.VOYAGE_EMBED_BATCH_SIZE``
    using a dummy key, so the test only needs the ``voyageai`` package to be
    importable; it does not need a real ``VOYAGEAI_API_KEY``.
    """
    pytest.importorskip("voyageai")
    with pytest.raises(ValueError, match="The maximum batch size supported is"):
        VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=99999999)


def test_voyage_max_batch_size_exceeded_in_call() -> None:
    """Passing more documents than ``max_batch_size`` must raise ``ValueError``.

    The size check in ``__call__`` fires before any request is sent, so a
    dummy key is enough; the test only needs the ``voyageai`` package to be
    importable, not a real ``VOYAGEAI_API_KEY``.
    """
    pytest.importorskip("voyageai")
    embed_fn = VoyageAIEmbeddingFunction(api_key="dummy", max_batch_size=1)
    with pytest.raises(ValueError, match="The maximum batch size supported is"):
        embed_fn(["test doc"] * 2)
131 changes: 72 additions & 59 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,9 +743,7 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:


class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
def __init__(
self, api_key: str = "", api_url = "https://infer.roboflow.com"
) -> None:
def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None:
"""
Create a RoboflowEmbeddingFunction.
Expand All @@ -757,7 +755,7 @@ def __init__(
api_key = os.environ.get("ROBOFLOW_API_KEY")

self._api_url = api_url
self._api_key = api_key
self._api_key = api_key

try:
self._PILImage = importlib.import_module("PIL.Image")
Expand Down Expand Up @@ -789,10 +787,10 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
json=infer_clip_payload,
)

result = res.json()['embeddings']
result = res.json()["embeddings"]

embeddings.append(result[0])

elif is_document(item):
infer_clip_payload = {
"text": input,
Expand All @@ -803,13 +801,13 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
json=infer_clip_payload,
)

result = res.json()['embeddings']
result = res.json()["embeddings"]

embeddings.append(result[0])

return embeddings


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
Expand Down Expand Up @@ -899,55 +897,70 @@ def __call__(self, input: Documents) -> Embeddings:
Embeddings, self._session.post(self._api_url, json={"inputs": input}).json()
)


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for Voyageai.com. API docs - https://docs.voyageai.com/reference/embeddings-api"""

    def __init__(
        self,
        api_key: str,
        model_name: str = "voyage-2",
        max_batch_size: int = 128,
        truncation: Optional[bool] = True,
        input_type: Optional[str] = None,
    ):
        """
        Initialize the VoyageAIEmbeddingFunction.

        Args:
            api_key (str): Your API key for the VoyageAI API.
            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "voyage-2".
            max_batch_size (int, optional): The maximum number of documents accepted per call. Defaults to 128.
                Must not exceed `voyageai.VOYAGE_EMBED_BATCH_SIZE`.
            truncation (bool, optional): Whether to truncate the input (`True`) or raise an error if the input
                is too long (`False`). Defaults to `True`.
            input_type (str, optional): The type of input text. Can be `None`, `query`, `document`.
                Defaults to `None`.

        Raises:
            ValueError: If `api_key` is empty, if the `voyageai` package is not installed, or if
                `max_batch_size` exceeds `voyageai.VOYAGE_EMBED_BATCH_SIZE`.
        """
        if not api_key:
            raise ValueError("Please provide a VoyageAI API key.")

        # Keep the try body down to the import so that only a missing package
        # is reported as "not installed".
        try:
            import voyageai
        except ImportError as exc:
            raise ValueError(
                "The VoyageAI python package is not installed. Please install it with `pip install voyageai`"
            ) from exc

        if max_batch_size > voyageai.VOYAGE_EMBED_BATCH_SIZE:
            raise ValueError(
                f"The maximum batch size supported is {voyageai.VOYAGE_EMBED_BATCH_SIZE}."
            )

        self._batch_size = max_batch_size
        self._model = model_name
        self._truncation = truncation
        self._input_type = input_type
        # Hand the key to this client only, instead of writing it to the
        # module-global `voyageai.api_key`, so instances created with
        # different keys do not overwrite each other's credentials.
        self._client = voyageai.Client(api_key=api_key)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Raises:
            ValueError: If `input` holds more documents than the configured maximum batch size.

        Example:
            >>> voyage_ef = VoyageAIEmbeddingFunction(api_key="your_api_key")
            >>> input = ["Hello, world!", "How are you?"]
            >>> embeddings = voyage_ef(input)
        """
        if len(input) > self._batch_size:
            raise ValueError(f"The maximum batch size supported is {self._batch_size}.")
        results = self._client.embed(
            texts=input,
            model=self._model,
            truncation=self._truncation,
            input_type=self._input_type,
        )
        return results.embeddings


def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore
Expand Down Expand Up @@ -1012,7 +1025,7 @@ def __call__(self, input: Documents) -> Embeddings: # type: ignore

return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
"""
This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
Expand Down Expand Up @@ -1068,7 +1081,7 @@ def __call__(self, input: Documents) -> Embeddings:
],
)


# List of all classes in this module
_classes = [
name
Expand All @@ -1078,4 +1091,4 @@ def __call__(self, input: Documents) -> Embeddings:


def get_builtins() -> List[str]:
    """Return the module-level ``_classes`` list of class names defined in this module."""
    return _classes

0 comments on commit c0cbbed

Please sign in to comment.