Create a base class for the models #1385
base: v1.0
That's smart. I am not a big fan of mixins, however, and I wonder if we should instead define at least one base class for all models, and also require that
I agree that mixins aren't great. I like the idea of having a base class for all models: on top of avoiding repetition, it would let us define and document the models' interface in one straightforward place. If we want to keep the formatting logic outside of the model classes so they don't end up with too many methods, without relying on inheritance either, one solution could be to require each model to have a formatter attribute whose class inherits from a base ModelFormatter class.
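A minimal sketch of that composition approach, assuming hypothetical names: only ModelFormatter comes from the comment, while format_input, format_output, StringFormatter and EchoModel are illustrative and not part of Outlines. The model holds a formatter object and delegates to it, so it does not inherit any formatting methods.

```python
from abc import ABC, abstractmethod
from typing import Any


class ModelFormatter(ABC):
    """Base class for formatting logic kept outside the model (hypothetical API)."""

    @abstractmethod
    def format_input(self, model_input: Any) -> Any:
        """Convert the user-facing input into what the backend expects."""

    @abstractmethod
    def format_output(self, raw_output: Any) -> Any:
        """Convert the backend's raw output into the user-facing result."""


class StringFormatter(ModelFormatter):
    """Hypothetical formatter for a backend that only accepts plain strings."""

    def format_input(self, model_input: Any) -> str:
        if not isinstance(model_input, str):
            raise TypeError(f"Unsupported input type: {type(model_input).__name__}")
        return model_input

    def format_output(self, raw_output: Any) -> str:
        return str(raw_output).strip()


class EchoModel:
    """Hypothetical model: it owns a formatter instead of inheriting formatting methods."""

    def __init__(self) -> None:
        self.formatter: ModelFormatter = StringFormatter()

    def generate(self, model_input: Any) -> str:
        prompt = self.formatter.format_input(model_input)
        raw = f"  echo: {prompt}  "  # stand-in for a real backend call
        return self.formatter.format_output(raw)


print(EchoModel().generate("hi"))  # echo: hi
```

Supporting a different set of input types then means swapping the formatter object, without touching the model's own class hierarchy.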
Force-pushed from 4b38b60 to 6627b34.
I don't have much to add from a technical perspective, but I like the single-inheritance approach; it's much easier to extend. I poked around and it doesn't seem like this has much of an impact on most Outlines users, right? Only people deep into the internals are going to be affected, and even then only slightly.
@@ -134,7 +180,7 @@ def from_pretrained(cls, repo_id, filename, **kwargs):
         model = Llama.from_pretrained(repo_id, filename, **kwargs)
         return cls(model)

-    def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str:
+    def generate(self, model_input, logits_processor, **inference_kwargs):
model_input is still a string for now.
I removed it because mypy reported an error: the method signature was inconsistent with that of the parent class's method. If we want to keep it, I think we would need to remove generate from the base class.
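A sketch of the other option (keeping generate in the base class and dropping the str annotation), assuming a hypothetical base class Model whose abstract generate takes model_input, logits_processor and **inference_kwargs, as in the diff above. The subclass repeats the parent's parameter names and annotations exactly and checks at runtime that model_input is a string, so the two signatures stay consistent. Model and StringOnlyModel are illustrative names, not Outlines classes.

```python
import inspect
from abc import ABC, abstractmethod
from typing import Any


class Model(ABC):
    """Hypothetical base class that keeps generate() in its interface."""

    @abstractmethod
    def generate(self, model_input: Any, logits_processor: Any, **inference_kwargs: Any) -> Any:
        ...


class StringOnlyModel(Model):
    """Hypothetical subclass whose override repeats the parent's signature."""

    def generate(self, model_input: Any, logits_processor: Any, **inference_kwargs: Any) -> Any:
        # Check the type at runtime instead of narrowing the annotation to `str`.
        if not isinstance(model_input, str):
            raise TypeError("model_input is still expected to be a string")
        return model_input.upper()  # stand-in for a real backend call


# The two signatures are identical, parameter names and annotations included.
print(inspect.signature(StringOnlyModel.generate) == inspect.signature(Model.generate))  # True
```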
A few nitpicks; looks good otherwise. To fix the CI:
Added a few superficial comments. Nice work!
Force-pushed from 6627b34 to 2b65abd, then to 6b15598, then to 768c524.
Updated on 2025-02-11
The objective of this PR is to create a base class that all models would inherit from. This base class defines the interface for models and includes a __call__ method that allows models to be called directly with a prompt and an output type, following the suggestion in #1359.

Formatting methods for the input and output types are implemented in a separate class inheriting from the base class ModelTypeAdapter. Each model must set a type_adapter attribute in its initialization. The rationale for separating those from the model class is that the types each model accepts and the formatting logic are quite different concerns from the specifics of the model's implementation. Separating them also makes the documentation of each model's accepted types easier to read.

Another change proposed in this PR is a mandatory model_type attribute in each model class. This would remove the need to rely on a static list of models through the APIModel and LocalModel classes.

The base model class is used by the models that have already been refactored: OpenAI, Gemini, Anthropic and Llamacpp.
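The design described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's actual code: ModelTypeAdapter, type_adapter, model_type, __call__ and generate come from the description, while the adapter method names (format_input, format_output_type), the "api" value, the __init_subclass__ check, ToyTypeAdapter, ToyAPIModel and is_api_model are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class ModelTypeAdapter(ABC):
    """Formatting logic for a model's accepted input and output types."""

    @abstractmethod
    def format_input(self, model_input: Any) -> Any:
        """Turn the user's input into what the backend expects (hypothetical name)."""

    @abstractmethod
    def format_output_type(self, output_type: Optional[Any]) -> Any:
        """Turn the requested output type into backend arguments (hypothetical name)."""


class Model(ABC):
    """Sketch of a base class shared by all models."""

    type_adapter: ModelTypeAdapter  # set by each model in __init__
    model_type: str  # mandatory class attribute, e.g. "api" (assumed value)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # One possible way to make model_type mandatory (an assumption).
        if not hasattr(cls, "model_type"):
            raise TypeError(f"{cls.__name__} must define a model_type attribute")

    def __call__(
        self, model_input: Any, output_type: Optional[Any] = None, **inference_kwargs: Any
    ) -> Any:
        # Lets users call a model directly with a prompt and an output type.
        return self.generate(model_input, output_type, **inference_kwargs)

    @abstractmethod
    def generate(
        self, model_input: Any, output_type: Optional[Any] = None, **inference_kwargs: Any
    ) -> Any:
        ...


class ToyTypeAdapter(ModelTypeAdapter):
    """Hypothetical adapter: string inputs only; output type must be None or int."""

    def format_input(self, model_input: Any) -> str:
        if isinstance(model_input, str):
            return model_input
        raise TypeError(f"Unsupported input type: {type(model_input).__name__}")

    def format_output_type(self, output_type: Optional[Any]) -> dict:
        if output_type is None:
            return {}
        if output_type is int:
            return {"response_format": "integer"}
        raise TypeError(f"Unsupported output type: {output_type!r}")


class ToyAPIModel(Model):
    """Hypothetical API model; it builds a request instead of calling a server."""

    model_type = "api"

    def __init__(self) -> None:
        self.type_adapter = ToyTypeAdapter()

    def generate(
        self, model_input: Any, output_type: Optional[Any] = None, **inference_kwargs: Any
    ) -> Any:
        return {
            "prompt": self.type_adapter.format_input(model_input),
            **self.type_adapter.format_output_type(output_type),
            **inference_kwargs,
        }


def is_api_model(model: Model) -> bool:
    # Dispatch on the attribute rather than a static list of model classes.
    return model.model_type == "api"


model = ToyAPIModel()
print(model("What is 2+2?", int))  # {'prompt': 'What is 2+2?', 'response_format': 'integer'}
print(is_api_model(model))  # True
```

Because callers can dispatch on model.model_type, adding a new model in this sketch only requires setting that attribute, rather than registering the class in a static list.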