Skip to content

Commit

Permalink
[SPARK-51015][ML][PYTHON][CONNECT] Support RFormulaModel.toString on …
Browse files Browse the repository at this point in the history
…Connect

### What changes were proposed in this pull request?

This PR adds support for `RFormulaModel.toString` on ML Connect.

### Why are the changes needed?

Feature parity

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
CI passes

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49745 from wbo4958/rf.tostring.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit f2d65ee)
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
wbo4958 authored and zhengruifeng committed Feb 2, 2025
1 parent 7100598 commit 7cedc87
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ class RFormulaModel private[feature](
s"RFormulaModel: uid=$uid, resolvedFormula=$resolvedFormula"
}

/**
 * Returns the string rendering of this model's resolved formula.
 *
 * For ml connect only.
 */
@Since("4.0.0")
private[ml] def resolvedFormulaString: String = {
  val formula = resolvedFormula
  formula.toString
}

private def transformLabel(dataset: Dataset[_]): DataFrame = {
val labelName = resolvedFormula.label
if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6841,7 +6841,7 @@ class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable["RFormulaModel"],
"""

def __str__(self) -> str:
resolvedFormula = self._call_java("resolvedFormula")
resolvedFormula = self._call_java("resolvedFormulaString")
return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid)


Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,8 +1312,7 @@ def test_rformula_string_indexer_order_type(self):

model.write().overwrite().save(d)
model2 = RFormulaModel.load(d)
# TODO: fix str(model)
# self.assertEqual(str(model), str(model2))
self.assertEqual(str(model), str(model2))
self.assertEqual(model.getFormula(), model2.getFormula())

def test_string_indexer_handle_invalid(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ private[ml] object MLUtils {
(classOf[CountVectorizerModel], Set("vocabulary")),
(classOf[OneHotEncoderModel], Set("categorySizes")),
(classOf[StringIndexerModel], Set("labels", "labelsArray")),
(classOf[RFormulaModel], Set("resolvedFormulaString")),
(classOf[IDFModel], Set("idf", "docFreq", "numDocs")))

private def validate(obj: Any, method: String): Unit = {
Expand Down

0 comments on commit 7cedc87

Please sign in to comment.