1515
1616package za .co .absa .hyperdrive .trigger .scheduler .executors .spark
1717
18- import org .apache .spark .launcher .{
19- InProcessLauncher ,
20- NoBackendConnectionInProcessLauncher ,
21- SparkAppHandle ,
22- SparkLauncher
23- }
18+ import org .apache .hadoop .security .UserGroupInformation
19+ import org .apache .spark .launcher .{InProcessLauncher , NoBackendConnectionInProcessLauncher , SparkAppHandle }
2420import org .springframework .stereotype .Service
2521import za .co .absa .hyperdrive .trigger .configuration .application .SparkConfig
2622import za .co .absa .hyperdrive .trigger .models .enums .JobStatuses .{Lost , SubmissionTimeout , Submitting }
2723import za .co .absa .hyperdrive .trigger .models .{JobInstance , SparkInstanceParameters }
2824import za .co .absa .hyperdrive .trigger .api .rest .utils .Extensions ._
2925
26+ import java .security .PrivilegedExceptionAction
3027import java .util .UUID .randomUUID
3128import java .util .concurrent .{CountDownLatch , TimeUnit }
3229import javax .inject .Inject
@@ -38,6 +35,8 @@ class SparkYarnClusterServiceImpl @Inject() (
3835 executionContextProvider : SparkClusterServiceExecutionContextProvider
3936) extends SparkClusterService {
4037 private implicit val executionContext : ExecutionContext = executionContextProvider.get()
38+ private val SparkYarnPrincipalProp = " spark.yarn.principal"
39+ private val SparkYarnKeytabProp = " spark.yarn.keytab"
4140
4241 override def submitJob (
4342 jobInstance : JobInstance ,
@@ -49,17 +48,18 @@ class SparkYarnClusterServiceImpl @Inject() (
4948 updateJob(ji).map { _ =>
5049 val submitTimeout = sparkConfig.yarn.submitTimeout
5150 val latch = new CountDownLatch (1 )
52- val sparkAppHandle =
53- getSparkLauncher(id, ji.jobName, jobParameters).startApplication(new SparkAppHandle .Listener {
54- import scala .math .Ordered .orderingToOrdered
55- override def stateChanged (handle : SparkAppHandle ): Unit =
56- if (handle.getState >= SparkAppHandle .State .SUBMITTED ) {
57- latch.countDown()
58- }
59- override def infoChanged (handle : SparkAppHandle ): Unit = {
60- // do nothing
51+ val sparkAppHandleListener = new SparkAppHandle .Listener {
52+ import scala .math .Ordered .orderingToOrdered
53+ override def stateChanged (handle : SparkAppHandle ): Unit =
54+ if (handle.getState >= SparkAppHandle .State .SUBMITTED ) {
55+ latch.countDown()
6156 }
62- })
57+ override def infoChanged (handle : SparkAppHandle ): Unit = {
58+ // do nothing
59+ }
60+ }
61+ val sparkAppHandle =
62+ startSparkJob(getSparkLauncher(id, ji.jobName, jobParameters), sparkAppHandleListener, jobParameters)
6363 latch.await(submitTimeout, TimeUnit .MILLISECONDS )
6464 sparkAppHandle.kill()
6565 }
@@ -103,6 +103,24 @@ class SparkYarnClusterServiceImpl @Inject() (
103103 sparkLauncher
104104 }
105105
106+ private def startSparkJob (inProcessLauncher : InProcessLauncher ,
107+ sparkAppHandleListener : SparkAppHandle .Listener ,
108+ jobParameters : SparkInstanceParameters
109+ ): SparkAppHandle = {
110+ val user = jobParameters.additionalSparkConfig.find(_.key == SparkYarnPrincipalProp ).map(_.value)
111+ val keytab = jobParameters.additionalSparkConfig.find(_.key == SparkYarnKeytabProp ).map(_.value)
112+ (user, keytab) match {
113+ case (Some (u), Some (k)) =>
114+ val ugi = UserGroupInformation .loginUserFromKeytabAndReturnUGI(u, k)
115+ ugi.doAs(new PrivilegedExceptionAction [SparkAppHandle ]() {
116+ override def run (): SparkAppHandle = {
117+ inProcessLauncher.startApplication(sparkAppHandleListener)
118+ }
119+ })
120+ case _ => inProcessLauncher.startApplication(sparkAppHandleListener)
121+ }
122+ }
123+
106124 /*
107125 Fix inspired by https://stackoverflow.com/questions/43040793/scala-via-spark-with-yarn-curly-brackets-string-missing
108126 See https://issues.apache.org/jira/browse/SPARK-17814
0 commit comments