Skip to content

Commit

Permalink
Initial version
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Jan 29, 2025
1 parent 2332f63 commit 3e7ab54
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ MACRO: 'MACRO';
MAP: 'MAP' {incComplexTypeLevelCounter();};
MATCHED: 'MATCHED';
MERGE: 'MERGE';
MESSAGE_TEXT: 'MESSAGE_TEXT';
MICROSECOND: 'MICROSECOND';
MICROSECONDS: 'MICROSECONDS';
MILLISECOND: 'MILLISECOND';
Expand Down Expand Up @@ -410,6 +411,7 @@ SETMINUS: 'MINUS';
SETS: 'SETS';
SHORT: 'SHORT';
SHOW: 'SHOW';
SIGNAL: 'SIGNAL';
SINGLE: 'SINGLE';
SKEWED: 'SKEWED';
SMALLINT: 'SMALLINT';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ compoundStatement
| beginEndCompoundBlock
| declareConditionStatement
| declareHandlerStatement
| signalStatement
| ifElseStatement
| caseStatement
| whileStatement
Expand Down Expand Up @@ -104,6 +105,11 @@ declareHandlerStatement
: DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValues (beginEndCompoundBlock | statement | setStatementWithOptionalVarKeyword)
;

signalStatement
: SIGNAL conditionName=multipartIdentifier (SET MESSAGE_TEXT EQ (msgStr=stringLit|msgVar=multipartIdentifier))? #signalStatementWithCondition
| SIGNAL SQLSTATE VALUE? sqlState=stringLit (SET MESSAGE_TEXT EQ (msgStr=stringLit|msgVar=multipartIdentifier))? #signalStatementWithSqlState
;

whileStatement
: beginLabel? WHILE booleanExpression DO compoundBody END WHILE endLabel?
;
Expand Down Expand Up @@ -1740,6 +1746,7 @@ ansiNonReserved
| MAP
| MATCHED
| MERGE
| MESSAGE_TEXT
| MICROSECOND
| MICROSECONDS
| MILLISECOND
Expand Down Expand Up @@ -1820,6 +1827,7 @@ ansiNonReserved
| SETS
| SHORT
| SHOW
| SIGNAL
| SINGLE
| SKEWED
| SMALLINT
Expand Down Expand Up @@ -2110,6 +2118,7 @@ nonReserved
| MAP
| MATCHED
| MERGE
| MESSAGE_TEXT
| MICROSECOND
| MICROSECONDS
| MILLISECOND
Expand Down Expand Up @@ -2201,6 +2210,7 @@ nonReserved
| SETS
| SHORT
| SHOW
| SIGNAL
| SINGLE
| SKEWED
| SMALLINT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,33 @@ class AstBuilder extends DataTypeAstBuilder
ExceptionHandler(exceptionHandlerTriggers, body, handlerType)
}

override def visitSignalStatementWithCondition(
ctx: SignalStatementWithConditionContext): SignalStatement = {
val messageString = Option(ctx.msgStr)
.map(sl => Left(string(visitStringLit(sl)).replace("'", "")))
val messageVariable = Option(ctx.msgVar)
.map(mpi => Right(UnresolvedAttribute(visitMultipartIdentifier(mpi))))

SignalStatement(
errorCondition = Some(ctx.conditionName.getText),
message = messageVariable.getOrElse(messageString.getOrElse(Left(""))))
}

override def visitSignalStatementWithSqlState(
ctx: SignalStatementWithSqlStateContext): SignalStatement = {
val sqlState = visitStringLit(ctx.sqlState).getText.replace("'", "")
assertSqlState(sqlState)

val messageString = Option(ctx.msgStr)
.map(sl => Left(string(visitStringLit(sl)).replace("'", "")))
val messageVariable = Option(ctx.msgVar)
.map(mpi => Right(UnresolvedAttribute(visitMultipartIdentifier(mpi))))

SignalStatement(
sqlState = Some(sqlState),
message = messageVariable.getOrElse(messageString.getOrElse(Left(""))))
}

private def visitCompoundBodyImpl(
ctx: CompoundBodyContext,
label: Option[String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

import java.util.Locale

import scala.collection.mutable.{HashMap, Set}
Expand Down Expand Up @@ -404,3 +406,21 @@ case class ExceptionHandler(
handlerType)
}
}

/**
* Logical operator for Signal Statement.
* @param errorCondition Name of the error condition/SQL State for error that will be thrown.
* @param sqlState SQL State for error that will be thrown.
* @param message Error message (either string or variable name).
*/
case class SignalStatement(
errorCondition: Option[String] = None,
sqlState: Option[String] = None,
message: Either[String, UnresolvedAttribute]) extends CompoundPlanStatement {
override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = Seq.empty

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this.copy()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.exceptions

import scala.jdk.CollectionConverters._

import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.exceptions.SqlScriptingRuntimeException.errorMessageWithLineNumber

class SqlScriptingRuntimeException (
condition: Option[String] = None,
sqlState: Option[String] = None,
message: String,
cause: Throwable,
val origin: Origin,
messageParameters: Map[String, String] = Map.empty)
extends Exception(
errorMessageWithLineNumber(Option(origin), condition, message, messageParameters),
cause)
with SparkThrowable {

def getCondition: String = condition.getOrElse("USER_RAISED_EXCEPTION")

override def getErrorClass: String = getCondition

override def getSqlState: String = sqlState.getOrElse(super.getSqlState)

override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
}

private object SqlScriptingRuntimeException {

private def errorMessageWithLineNumber(
origin: Option[Origin],
condition: Option[String] = None,
message: String,
messageParameters: Map[String, String]): String = {
val prefix = origin.flatMap(o => o.line.map(l => s"{LINE:$l} ")).getOrElse("")
prefix + message
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package org.apache.spark.sql.scripting
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.exceptions.SqlScriptingRuntimeException

/**
* SQL scripting executor - executes script and returns result statements.
Expand Down Expand Up @@ -92,6 +94,14 @@ class SqlScriptingExecution(
// While we don't have a result statement, execute the statements.
while (currentStatement.isDefined) {
currentStatement match {
case Some(signalStatementExec: SignalStatementExec) =>
throw new SqlScriptingRuntimeException(
condition = signalStatementExec.errorCondition,
sqlState = signalStatementExec.sqlState,
message = signalStatementExec.getMessageText,
cause = null,
origin = CurrentOrigin.get
)
case Some(stmt: SingleStatementExec) if !stmt.isExecuted =>
withErrorHandling {
val df = stmt.buildDataFrame(session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ import java.util
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal}
import org.apache.spark.sql.catalyst.analysis.{ColumnResolutionHelper, NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal, VariableReference}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.plans.logical.HandlerType.HandlerType
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.errors.SqlScriptingErrors
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.{QueryCompilationErrors, SqlScriptingErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, StringType}

/**
* Trait for all SQL scripting execution nodes used during interpretation phase.
Expand Down Expand Up @@ -1007,3 +1009,59 @@ class ErrorHandlerExec(

override def reset(): Unit = body.reset()
}

/**
* Executable node for Signal Statement.
* @param errorCondition Name of the error condition/SQL State for error that will be thrown.
* @param sqlState SQL State of the error that will be thrown.
* @param message Error message (either string or variable name).
* @param session Spark session that SQL script is executed within.
*/
class SignalStatementExec(
val errorCondition: Option[String] = None,
val sqlState: Option[String] = None,
val message: Either[String, UnresolvedAttribute],
val session: SparkSession)
extends LeafStatementExec
with ColumnResolutionHelper {

override def catalogManager: CatalogManager = session.sessionState.catalogManager
override def conf: SQLConf = session.sessionState.conf

def getMessageText: String = {
message match {
case Left(v) => v
case Right(u) =>
val varReference = getVariableReference(u, u.nameParts)

if (!varReference.dataType.sameType(StringType)) {
throw QueryCompilationErrors.invalidExecuteImmediateVariableType(varReference.dataType)
}

// Call eval with null value passed instead of a row.
// This is ok as this is variable and invoking eval should
// be independent of row value.
val varReferenceValue = varReference.eval(null)

if (varReferenceValue == null) {
throw QueryCompilationErrors.nullSQLStringExecuteImmediate(u.name)
}

varReferenceValue.toString
}
}

private def getVariableReference(expr: Expression, nameParts: Seq[String]): VariableReference = {
lookupVariable(nameParts) match {
case Some(variable) => variable
case _ =>
throw QueryCompilationErrors
.unresolvedVariableError(
nameParts,
Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE),
expr.origin)
}
}

override def reset(): Unit = ()
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, HandlerType, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, HandlerType, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SignalStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.errors.SqlScriptingErrors
Expand Down Expand Up @@ -274,6 +274,13 @@ case class SqlScriptingInterpreter(session: SparkSession) {
case iterateStatement: IterateStatement =>
new IterateStatementExec(iterateStatement.label)

case signalStatement: SignalStatement =>
new SignalStatementExec(
errorCondition = signalStatement.errorCondition,
sqlState = signalStatement.sqlState,
message = signalStatement.message,
session = session)

case sparkStatement: SingleStatement =>
new SingleStatementExec(
sparkStatement.parsedPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
parameters = Map("sqlState" -> "12345"))
}

test("SIGNAL") {
val sqlScript =
"""
|BEGIN
| DECLARE OR REPLACE flag INT = -1;
| BEGIN
| DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO
| BEGIN
| SELECT flag;
| SET VAR flag = 1;
| END;
|
| SIGNAL DIVIDE_BY_ZERO;
| END;
| SELECT flag;
|END
|""".stripMargin
val expected = Seq(
Seq(Row(-1)), // select flag
Seq(Row(1)) // select flag from the outer body
)
verifySqlScriptResult(sqlScript, expected)
}

test("Specific condition takes precedence over sqlState") {
val sqlScript =
"""
Expand Down

0 comments on commit 3e7ab54

Please sign in to comment.