Expose RFormulaModel's resolved formula as a string and use it for str(model) in PySpark.

Summary of the hunks below:
- mllib/.../ml/feature/RFormula.scala: adds `private[ml] def resolvedFormulaString: String`, annotated `@Since("4.0.0")`, which returns `resolvedFormula.toString`. The source comment added with it marks the method "For ml connect only".
- python/pyspark/ml/feature.py: `RFormulaModel.__str__` now calls the JVM method `resolvedFormulaString` instead of `resolvedFormula`; the returned format string "RFormulaModel(%s) (uid=%s)" is unchanged.
- python/pyspark/ml/tests/test_feature.py: in `test_rformula_string_indexer_order_type`, removes the "TODO: fix str(model)" note and turns the commented-out `self.assertEqual(str(model), str(model2))` into a live assertion, placed after the model is saved and reloaded.
- sql/connect/server/.../connect/ml/MLUtils.scala: adds `(classOf[RFormulaModel], Set("resolvedFormulaString"))` to the sequence of (model class, method-name set) pairs. NOTE(review): presumably this sequence is the allow-list consulted by `validate(obj, method)` -- confirm against the full file.

NOTE(review): the patch text below has had its line breaks and indentation collapsed (each hunk header declares several lines, yet the hunk bodies sit on a single line). It must be re-split on the original line boundaries, with the original leading whitespace, before it can be applied.

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index a3ce4f3239b60..2eb37fd65d7fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -387,6 +387,10 @@ class RFormulaModel private[feature]( s"RFormulaModel: uid=$uid, resolvedFormula=$resolvedFormula" } + // For ml connect only + @Since("4.0.0") + private[ml] def resolvedFormulaString: String = resolvedFormula.toString + private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label if (labelName.isEmpty || hasLabelCol(dataset.schema)) { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 81f6c7ebcbdf0..cc8f19aaf82c2 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -6841,7 +6841,7 @@ class RFormulaModel(JavaModel, _RFormulaParams, JavaMLReadable["RFormulaModel"], """ def __str__(self) -> str: - resolvedFormula = self._call_java("resolvedFormula") + resolvedFormula = self._call_java("resolvedFormulaString") return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid) diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 12926e3e5bb46..cc47832950799 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -1312,8 +1312,7 @@ def test_rformula_string_indexer_order_type(self): model.write().overwrite().save(d) model2 = RFormulaModel.load(d) - # TODO: fix str(model) - # self.assertEqual(str(model), str(model2)) + self.assertEqual(str(model), str(model2)) self.assertEqual(model.getFormula(), model2.getFormula()) def test_string_indexer_handle_invalid(self): diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index 
acf5abd938f74..6f59af46d0a7a 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -648,6 +648,7 @@ private[ml] object MLUtils { (classOf[CountVectorizerModel], Set("vocabulary")), (classOf[OneHotEncoderModel], Set("categorySizes")), (classOf[StringIndexerModel], Set("labels", "labelsArray")), + (classOf[RFormulaModel], Set("resolvedFormulaString")), (classOf[IDFModel], Set("idf", "docFreq", "numDocs"))) private def validate(obj: Any, method: String): Unit = {