From 2e28f1ffa74ec8224e665275211ab509ed4f8758 Mon Sep 17 00:00:00 2001 From: Jitendra Patil Date: Fri, 6 Jan 2017 19:24:53 -0800 Subject: [PATCH] DPNG-12495: revise model feature on 0.7.4 (#9) * revise model feature on 0.7.4 * cleanup --- .../DataOutputFormatJsonProtocol.scala | 179 ++++++++++++++++++ .../scoring/ScoringEngineHelper.scala | 158 ++++++++++++++++ .../scoring/ScoringService.scala | 167 +++++++++++----- .../scoring/ScoringServiceApplication.scala | 124 ++---------- .../scoring/ScoringServiceJsonProtocol.scala | 165 +--------------- .../ScoringServiceJsonProtocolTest.scala | 8 +- 6 files changed, 476 insertions(+), 325 deletions(-) create mode 100644 src/main/scala/org/trustedanalytics/scoring/DataOutputFormatJsonProtocol.scala create mode 100644 src/main/scala/org/trustedanalytics/scoring/ScoringEngineHelper.scala diff --git a/src/main/scala/org/trustedanalytics/scoring/DataOutputFormatJsonProtocol.scala b/src/main/scala/org/trustedanalytics/scoring/DataOutputFormatJsonProtocol.scala new file mode 100644 index 0000000..1b76059 --- /dev/null +++ b/src/main/scala/org/trustedanalytics/scoring/DataOutputFormatJsonProtocol.scala @@ -0,0 +1,179 @@ +package org.trustedanalytics.scoring + +import org.joda.time.DateTime +import org.trustedanalytics.scoring.interfaces.Model +import spray.json.{JsString, _} + +import scala.collection.immutable.Map +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe._ +import spray.json.DefaultJsonProtocol._ + +class DataOutputFormatJsonProtocol(model: Model) { + + implicit object DataInputFormat extends JsonFormat[Seq[Array[Any]]] { + + //don't need this method. just there to satisfy the API. + override def write(obj: Seq[Array[Any]]): JsValue = ??? + + override def read(json: JsValue): Seq[Array[Any]] = { + val records = json.asJsObject.getFields("records") match { + case Seq(JsArray(records)) => records + case x => deserializationError(s"Expected array of records but got $x") + } + decodeRecords(records) + } + } + + implicit object DataOutputFormat extends JsonFormat[Array[Map[String, Any]]] { + + override def write(obj: Array[Map[String, Any]]): JsValue = { + val modelMetadata = model.modelMetadata() + JsObject("data" -> new JsArray(obj.map(output => DataTypeJsonFormat.write(output)).toList)) + } + + //don't need this method. just there to satisfy the API. + override def read(json: JsValue): Array[Map[String, Any]] = ??? + } + + def decodeRecords(records: List[JsValue]): Seq[Array[Any]] = { + val decodedRecords: Seq[Map[String, Any]] = records.map { record => + record match { + case JsObject(fields) => + val decodedRecord: Map[String, Any] = for ((feature, value) <- fields) yield (feature, decodeJValue(value)) + decodedRecord + } + } + var features: Seq[Array[Any]] = Seq[Array[Any]]() + decodedRecords.foreach(decodedRecord => { + val obsColumns = model.input() + val featureArray = new Array[Any](obsColumns.length) + if (decodedRecord.size != featureArray.length) { + throw new IllegalArgumentException( + "Size of input record is not equal to number of observation columns that model was trained on:\n" + + s"""Expected columns are: [${obsColumns.mkString(",")}]""" + ) + } + decodedRecord.foreach({ + case (name, value) => { + var counter = 0 + var found = false + while (counter < obsColumns.length && !found) { + if (obsColumns(counter).name != name) { + counter = counter + 1 + } + else { + featureArray(counter) = value + found = true + } + } + if (!found) { + throw new IllegalArgumentException( + s"""$name was not found in list of Observation Columns that model was trained on: [${obsColumns.mkString(",")}]""" + ) + } + + } + }) + features = features :+ featureArray + }) + features + } + + + def decodeJValue(v: JsValue): Any = { + v match { + case JsString(s) => s + case JsNumber(n) => n.toDouble + case JsArray(items) => for (item <- items) yield decodeJValue(item) + case JsNull => null + case JsObject(fields) => + val decodedValue: Map[String, Any] = for ((feature, value) <- fields) yield (feature, decodeJValue(value)) + decodedValue + case x => deserializationError(s"Unexpected JSON type in record $x") + } + } + + private def mapToJson[K <: Any, V <: Any](m: Map[K, V]): JsObject = { + require(m != null, s"Scoring service cannot serialize null to JSON") + val jsMap: Map[String, JsValue] = m.map { + case (x) => x match { + case (k: String, n: Double) => (k, n.toJson) + case (k: String, n: Int) => (k, n.toJson) + case (k: String, n: Long) => (k, n.toJson) + case (k: String, n: Float) => (k, n.toJson) + case (k: String, str: String) => (k, JsString(str)) + case (k: String, list: List[_]) => (k, listToJson(list)) + case (k: String, array: Array[_]) => (k, listToJson(array.toList)) + case (k: String, vector: Vector[_]) => (k, listToJson(vector.toList)) + case unk => serializationError(s"Scoring service cannot serialize ${unk.getClass.getName} to JSON") + } + } + JsObject(jsMap) + } + + implicit object DataTypeJsonFormat extends JsonFormat[Any] { + override def write(obj: Any): JsValue = { + obj match { + case n: Int => new JsNumber(n) + case n: Long => new JsNumber(n) + case n: Float => new JsNumber(BigDecimal(n)) + case n: Double => new JsNumber(n) + case s: String => new JsString(s) + case s: Boolean => JsBoolean(s) + case dt: DateTime => JsString(org.joda.time.format.ISODateTimeFormat.dateTime.print(dt)) + case m: Map[_, _] @unchecked => mapToJson(m) + case v: List[_] => listToJson(v) + case v: Array[_] => listToJson(v.toList) + case v: Vector[_] => listToJson(v.toList) + case v: ArrayBuffer[_] @unchecked => listToJson(v.toList) + case n: java.lang.Long => new JsNumber(n.longValue()) + // case null => JsNull Consciously not writing nulls, may need to change, but for now it may catch bugs + case unk => + val name: String = if (unk != null) { + unk.getClass.getName + } + else { + "null" + } + serializationError("Cannot serialize " + name) + } + } + + override def read(json: JsValue): Any = { + json match { + case JsNumber(n) if n.isValidInt => n.intValue() + case JsNumber(n) if n.isValidLong => n.longValue() + case JsNumber(n) if n.isValidFloat => n.floatValue() + case JsNumber(n) => n.doubleValue() + case JsBoolean(b) => b + case JsString(s) => s + case JsArray(v) => v.map(x => read(x)) + case obj: JsObject => obj.fields.map { + case (a, JsArray(v)) => (a, v.map(x => read(x))) + case (a, JsNumber(b)) => (a, b) + } + case unk => deserializationError("Cannot deserialize " + unk.getClass.getName) + } + } + } + + private def listToJson(list: List[Any]): JsArray = { + require(list != null, s"Scoring service cannot serialize null to JSON") + val jsElements = list.map { + case n: Double => n.toJson + case n: Int => n.toJson + case n: Long => n.toJson + case n: Float => n.toJson + case str: String => str.toJson + case map: Map[_, _] @unchecked => mapToJson(map) + case list: List[_] => listToJson(list) + case arr: Array[_] => listToJson(arr.toList) + case vector: Vector[_] => listToJson(vector.toList) + case unk => serializationError(s"Scoring service cannot serialize ${unk.getClass.getName} to Json") + } + new JsArray(jsElements) + } + + +} diff --git a/src/main/scala/org/trustedanalytics/scoring/ScoringEngineHelper.scala b/src/main/scala/org/trustedanalytics/scoring/ScoringEngineHelper.scala new file mode 100644 index 0000000..bf9e266 --- /dev/null +++ b/src/main/scala/org/trustedanalytics/scoring/ScoringEngineHelper.scala @@ -0,0 +1,158 @@ +/** + * Copyright (c) 2015 Intel Corporation  + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + *       http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.trustedanalytics.scoring + +import java.io.{ FileOutputStream, File, FileInputStream } + +import akka.actor.{ ActorSystem, Props } +import akka.io.IO +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{ FileSystem, Path } +import org.trustedanalytics.hadoop.config.client.oauth.TapOauthToken +import spray.can.Http +import akka.pattern.ask +import akka.util.Timeout +import scala.concurrent.duration._ +import com.typesafe.config.{ Config, ConfigFactory } +import scala.reflect.ClassTag +import org.trustedanalytics.scoring.interfaces.Model +import org.trustedanalytics.hadoop.config.client.helper.Hdfs +import java.net.URI +import org.apache.commons.io.FileUtils + +import org.apache.http.message.BasicNameValuePair +import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials } +import org.apache.http.impl.client.{ BasicCredentialsProvider, HttpClientBuilder } +import org.apache.http.client.methods.{ HttpPost, CloseableHttpResponse } +import org.apache.http.HttpHost +import org.apache.http.client.config.RequestConfig +import org.apache.http.client.entity.UrlEncodedFormEntity +import java.util.{ ArrayList => JArrayList } +import spray.json._ +import org.trustedanalytics.model.archive.format.ModelArchiveFormat +import org.slf4j.LoggerFactory + +object ScoringEngineHelper { + private val logger = LoggerFactory.getLogger(this.getClass) + val config = ConfigFactory.load(this.getClass.getClassLoader) + /** + * + * @param model Original model + * @param revisedModel Revised model + * @return true if model compatible i.e model type is same and model input/output parameters are same + * else returns false + */ + def isModelCompatible(model: Model, revisedModel: Model): Boolean = { + model.modelMetadata().modelType == revisedModel.modelMetadata().modelType && + model.input().deep == revisedModel.input().deep && + model.output().deep == revisedModel.output().deep + } + + def getModel(modelFilePath: String): Model = { + var tempMarFile: File = null + var marFilePath = modelFilePath + try { + if (marFilePath.startsWith("hdfs:/")) { + if (!marFilePath.startsWith("hdfs://")) { + val relativePath = marFilePath.substring(marFilePath.indexOf("hdfs:") + 6) + marFilePath = "hdfs://" + relativePath + } + + val hdfsFileSystem = try { + val token = new TapOauthToken(getJwtToken()) + logger.info(s"Successfully retreived a token for user ${token.getUserName}") + Hdfs.newInstance().createFileSystem(token) + } + catch { + case t: Throwable => + t.printStackTrace() + logger.info("Failed to create HDFS instance using hadoop-library. Default to FileSystem") + org.apache.hadoop.fs.FileSystem.get(new URI(marFilePath), new Configuration()) + } + tempMarFile = File.createTempFile("model", ".mar") + hdfsFileSystem.copyToLocalFile(false, new Path(marFilePath), new Path(tempMarFile.getAbsolutePath)) + marFilePath = tempMarFile.getAbsolutePath + } + logger.info("calling ModelArchiveFormat to get the model") + sys.addShutdownHook(FileUtils.deleteQuietly(tempMarFile)) // Delete temporary directory on exit + ModelArchiveFormat.read(new File(marFilePath), this.getClass.getClassLoader, None) + } + finally { + FileUtils.deleteQuietly(tempMarFile) + } + } + + def getJwtToken(): String = { + + val query = s"http://${System.getenv("UAA_URI")}/oauth/token" + val headers = List(("Accept", "application/json")) + val data = List(("username", System.getenv("FS_TECHNICAL_USER_NAME")), ("password", System.getenv("FS_TECHNICAL_USER_PASSWORD")), ("grant_type", "password")) + val credentialsProvider = new BasicCredentialsProvider() + credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(System.getenv("UAA_CLIENT_NAME"), System.getenv("UAA_CLIENT_PASSWORD"))) + + // TODO: This method uses Apache HttpComponents HttpClient as spray-http library does not support proxy over https + val (proxyHostConfigString, proxyPortConfigString) = ("https.proxyHost", "https.proxyPort") + val httpClient = HttpClientBuilder.create().setDefaultCredentialsProvider(credentialsProvider).build() + try { + val proxy = (sys.props.contains(proxyHostConfigString), sys.props.contains(proxyPortConfigString)) match { + case (true, true) => Some(new HttpHost(sys.props.get(proxyHostConfigString).get, sys.props.get(proxyPortConfigString).get.toInt)) + case _ => None + } + + val config = { + val cfg = RequestConfig.custom().setConnectTimeout(30) + if (proxy.isDefined) + cfg.setProxy(proxy.get).build() + else cfg.build() + } + + val request = new HttpPost(query) + val nvps = new JArrayList[BasicNameValuePair] + data.foreach { case (k, v) => nvps.add(new BasicNameValuePair(k, v)) } + request.setEntity(new UrlEncodedFormEntity(nvps)) + + for ((headerTag, headerData) <- headers) + request.addHeader(headerTag, headerData) + request.setConfig(config) + + var response: Option[CloseableHttpResponse] = None + try { + response = Some(httpClient.execute(request)) + val inputStream = response.get.getEntity().getContent + val result = scala.io.Source.fromInputStream(inputStream).getLines().mkString("\n") + logger.info(s"Response From UAA Server is $result") + result.parseJson.asJsObject().getFields("access_token") match { + case values => values(0).asInstanceOf[JsString].value + } + } + catch { + case ex: Throwable => + error(s"Error executing request ${ex.getMessage}") + // We need this exception to be thrown as this is a generic http request method and let caller handle. + throw ex + } + finally { + if (response.isDefined) + response.get.close() + } + } + finally { + httpClient.close() + } + } + +} diff --git a/src/main/scala/org/trustedanalytics/scoring/ScoringService.scala b/src/main/scala/org/trustedanalytics/scoring/ScoringService.scala index b7995c4..32c109f 100644 --- a/src/main/scala/org/trustedanalytics/scoring/ScoringService.scala +++ b/src/main/scala/org/trustedanalytics/scoring/ScoringService.scala @@ -24,9 +24,12 @@ import akka.event.Logging import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import ExecutionContext.Implicits.global -import scala.util.{ Failure, Success } +import scala.util.{Failure, Success} import org.trustedanalytics.scoring.interfaces.Model import spray.json._ +import DefaultJsonProtocol._ +import org.trustedanalytics.scoring.ScoringServiceJsonProtocol._ + /** * We don't implement our route structure directly in the service actor because @@ -54,18 +57,7 @@ class ScoringServiceActor(val scoringService: ScoringService) extends Actor with /** * Defines our service behavior independently from the service actor */ -class ScoringService(model: Model) extends Directives { - def homepage = { - respondWithMediaType(`text/html`) { - complete { - - -

Welcome to the New Scoring Engine

- - - } - } - } +class ScoringService(scoringModel: Model) extends Directives { lazy val description = { new ServiceDescription(name = "Trusted Analytics", @@ -73,11 +65,14 @@ class ScoringService(model: Model) extends Directives { versions = List("v1", "v2")) } - val jsonFormat = new ScoringServiceJsonProtocol(model) - import jsonFormat._ - import spray.json._ + case class ModelData(model: Model, jsonFormat: DataOutputFormatJsonProtocol) + + var modelData = ModelData(scoringModel, + new DataOutputFormatJsonProtocol(scoringModel)) + + /** * Main Route entry point to the Scoring Server */ @@ -86,7 +81,11 @@ class ScoringService(model: Model) extends Directives { val metadataPrefix = "metadata" path("") { get { - homepage + respondWithMediaType(`text/html`) { + complete( + getHomePage(modelData) + ) + } } } ~ path("v2" / prefix) { @@ -95,13 +94,7 @@ class ScoringService(model: Model) extends Directives { entity(as[String]) { scoreArgs => val json: JsValue = scoreArgs.parseJson - import jsonFormat.DataOutputFormat - onComplete(Future { scoreJsonRequest(DataInputFormat.read(json)) }) { - case Success(output) => complete(DataOutputFormat.write(output).toString()) - case Failure(ex) => ctx => { - ctx.complete(StatusCodes.InternalServerError, ex.getMessage) - } - } + getScore(modelData, json) } } } @@ -116,50 +109,136 @@ class ScoringService(model: Model) extends Directives { val splitSegment = decoded.split(",") records = records :+ splitSegment.asInstanceOf[Array[Any]] } - onComplete(Future { scoreStringRequest(records) }) { - case Success(string) => complete(string.mkString(",")) - case Failure(ex) => ctx => { - ctx.complete(StatusCodes.InternalServerError, ex.getMessage) - } - } + getScoreV1(modelData, records) } } } ~ path("v2" / metadataPrefix) { requestUri { uri => get { - import spray.json._ - onComplete(Future { model.modelMetadata() }) { - case Success(metadata) => complete(JsObject("model_details" -> metadata.toJson, - "input" -> new JsArray(model.input.map(input => FieldFormat.write(input)).toList), - "output" -> new JsArray(model.output.map(output => FieldFormat.write(output)).toList)).toString) - case Failure(ex) => ctx => { - ctx.complete(StatusCodes.InternalServerError, ex.getMessage) - - } + getMetaData(modelData) + } + } + }~ + path("v2" / "revise") { + requestUri { uri => + post { + entity(as[String]) { + args => + this.synchronized { + val path = if (args.parseJson.asJsObject.getFields("model-path").size == 1) { + args.parseJson.asJsObject.getFields("model-path")(0).convertTo[String] + } + else { + null + } + //if request data contains "force = true" , then force switch should be true, else false + val force = if (args.parseJson.asJsObject.getFields("force").size == 1) { + if (args.parseJson.asJsObject.getFields("force")(0).convertTo[String].toLowerCase == "true") true else false + } + else { + false + } + reviseModelData(modelData, path, force) + } } } } } } - def scoreStringRequest(records: Seq[Array[Any]]): Array[Any] = { + def scoreStringRequest(modelData: ModelData, records: Seq[Array[Any]]): Array[Any] = { records.map(row => { - val score = model.score(row) + val score = modelData.model.score(row) score(score.length - 1).toString }).toArray } - def scoreJsonRequest(records: Seq[Array[Any]]): Array[Map[String, Any]] = { - records.map(row => scoreToMap(model.score(row))).toArray + def scoreJsonRequest(modeldata: ModelData, json: JsValue): Array[Map[String, Any]] = { + val records = modeldata.jsonFormat.DataInputFormat.read(json) + records.map(row => scoreToMap(modeldata.model, modeldata.model.score(row))).toArray } - def scoreToMap(score: Array[Any]): Map[String, Any] = { + def scoreToMap(model: Model, score: Array[Any]): Map[String, Any] = { val outputNames = model.output().map(o => o.name) require(score.length == outputNames.length, "Length of output values should match the output names") val outputMap: Map[String, Any] = outputNames.zip(score).map(combined => (combined._1.name, combined._2)).toMap outputMap } + + private def getScore(md: ModelData, json: JsValue): Route = { + onComplete(Future { scoreJsonRequest(md, json) }) { + case Success(output) => complete(md.jsonFormat.DataOutputFormat.write(output).toString()) + case Failure(ex) => ctx => { + ctx.complete(StatusCodes.InternalServerError, ex.getMessage) + } + } + } + + private def getScoreV1(md: ModelData, records: Seq[Array[Any]]): Route = { + onComplete(Future { scoreStringRequest(md, records) }) { + case Success(string) => complete(string.mkString(",")) + case Failure(ex) => ctx => { + ctx.complete(StatusCodes.InternalServerError, ex.getMessage) + } + } + } + + private def getMetaData(md: ModelData): Route = { + import spray.json._ + onComplete(Future { md.model.modelMetadata() }) { + case Success(metadata) => complete(JsObject("model_details" -> metadata.toJson, + "input" -> new JsArray(md.model.input.map(input => FieldFormat.write(input)).toList), + "output" -> new JsArray(md.model.output.map(output => FieldFormat.write(output)).toList)).toString()) + case Failure(ex) => ctx => { + ctx.complete(StatusCodes.InternalServerError, ex.getMessage) + } + } + } + private def reviseModelData(md: ModelData, modelPath: String, force: Boolean = false): Route = { + if (modelPath == null) { + complete(StatusCodes.BadRequest, "'model-path' is not present in request!") + } + else { + try { + val revisedModel = ScoringEngineHelper.getModel(modelPath) + if (force || ScoringEngineHelper.isModelCompatible(modelData.model, revisedModel)) { + modelData = ModelData(revisedModel, new DataOutputFormatJsonProtocol(revisedModel)) + complete { """{"status": "success"}""" } + } + else { + complete(StatusCodes.BadRequest, "Revised Model type or input-output parameters names are " + + "different than existing model") + } + } + catch { + case e: Throwable => + modelData = md + e.printStackTrace() + if (e.getMessage.contains("File does not exist:")) { + complete(StatusCodes.BadRequest, e.getMessage) + } + else { + complete(StatusCodes.InternalServerError, e.getMessage) + } + } + } + } + + private def getHomePage(md: ModelData): String = { + val metadata = JsObject("model_details" -> md.model.modelMetadata().toJson, + "input" -> new JsArray(md.model.input.map(input => FieldFormat.write(input)).toList), + "output" -> new JsArray(md.model.output.map(output => FieldFormat.write(output)).toList)).prettyPrint + + s""" + + +

Welcome to the Scoring Engine

+

Model details:

+ Model metadata:
 $metadata 
+ + """ + } } case class ServiceDescription(name: String, identifier: String, versions: List[String]) diff --git a/src/main/scala/org/trustedanalytics/scoring/ScoringServiceApplication.scala b/src/main/scala/org/trustedanalytics/scoring/ScoringServiceApplication.scala index 9c47dff..77f944c 100644 --- a/src/main/scala/org/trustedanalytics/scoring/ScoringServiceApplication.scala +++ b/src/main/scala/org/trustedanalytics/scoring/ScoringServiceApplication.scala @@ -15,35 +15,18 @@ */ package org.trustedanalytics.scoring -import java.io.{ FileOutputStream, File, FileInputStream } +import java.util.{ArrayList => JArrayList} -import akka.actor.{ ActorSystem, Props } +import akka.actor.{ActorSystem, Props} import akka.io.IO -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{ FileSystem, Path } -import org.trustedanalytics.hadoop.config.client.oauth.TapOauthToken -import spray.can.Http import akka.pattern.ask import akka.util.Timeout -import scala.concurrent.duration._ -import com.typesafe.config.{ Config, ConfigFactory } -import scala.reflect.ClassTag +import com.typesafe.config.ConfigFactory +import org.slf4j.LoggerFactory import org.trustedanalytics.scoring.interfaces.Model -import org.trustedanalytics.hadoop.config.client.helper.Hdfs -import java.net.URI -import org.apache.commons.io.FileUtils +import spray.can.Http -import org.apache.http.message.BasicNameValuePair -import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials } -import org.apache.http.impl.client.{ BasicCredentialsProvider, HttpClientBuilder } -import org.apache.http.client.methods.{ HttpPost, CloseableHttpResponse } -import org.apache.http.HttpHost -import org.apache.http.client.config.RequestConfig -import org.apache.http.client.entity.UrlEncodedFormEntity -import java.util.{ ArrayList => JArrayList } -import spray.json._ -import org.trustedanalytics.model.archive.format.ModelArchiveFormat -import org.slf4j.LoggerFactory +import scala.concurrent.duration._ /** * Scoring Service Application - a REST application used by client layer to communicate with the Model. @@ -73,40 +56,13 @@ class ScoringServiceApplication { * load the model saved at the given path * @return Model running inside the scoring engine instance */ - private def getModel: Model = { - var tempMarFile: File = null - try { - var marFilePath = config.getString("trustedanalytics.scoring-engine.archive-mar") - if (marFilePath.startsWith("hdfs:/")) { - if (!marFilePath.startsWith("hdfs://")) { - val relativePath = marFilePath.substring(marFilePath.indexOf("hdfs:") + 6) - marFilePath = "hdfs://" + relativePath - } - - val hdfsFileSystem = try { - val token = new TapOauthToken(getJwtToken()) - logger.info(s"Successfully retreived a token for user ${token.getUserName}") - Hdfs.newInstance().createFileSystem(token) - } - catch { - case t: Throwable => - t.printStackTrace() - logger.info("Failed to create HDFS instance using hadoop-library. Default to FileSystem") - org.apache.hadoop.fs.FileSystem.get(new URI(marFilePath), new Configuration()) - } - tempMarFile = File.createTempFile("model", ".mar") - hdfsFileSystem.copyToLocalFile(false, new Path(marFilePath), new Path(tempMarFile.getAbsolutePath)) - marFilePath = tempMarFile.getAbsolutePath - } - logger.info("calling ModelArchiveFormat to get the model") - sys.addShutdownHook(FileUtils.deleteQuietly(tempMarFile)) // Delete temporary directory on exit - ModelArchiveFormat.read(new File(marFilePath), this.getClass.getClassLoader, None) - } - finally { - FileUtils.deleteQuietly(tempMarFile) - } + private def getModel(): Model = { + val marFilePath = config.getString("trustedanalytics.scoring-engine.archive-mar") + ScoringEngineHelper.getModel(marFilePath) } + + /** * We need an ActorSystem to host our application in and to bind it to an HTTP port */ @@ -129,64 +85,6 @@ class ScoringServiceApplication { } } - def getJwtToken(): String = { - - val query = s"http://${System.getenv("UAA_URI")}/oauth/token" - val headers = List(("Accept", "application/json")) - val data = List(("username", System.getenv("FS_TECHNICAL_USER_NAME")), ("password", System.getenv("FS_TECHNICAL_USER_PASSWORD")), ("grant_type", "password")) - val credentialsProvider = new BasicCredentialsProvider() - credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(System.getenv("UAA_CLIENT_NAME"), System.getenv("UAA_CLIENT_PASSWORD"))) - - // TODO: This method uses Apache HttpComponents HttpClient as spray-http library does not support proxy over https - val (proxyHostConfigString, proxyPortConfigString) = ("https.proxyHost", "https.proxyPort") - val httpClient = HttpClientBuilder.create().setDefaultCredentialsProvider(credentialsProvider).build() - try { - val proxy = (sys.props.contains(proxyHostConfigString), sys.props.contains(proxyPortConfigString)) match { - case (true, true) => Some(new HttpHost(sys.props.get(proxyHostConfigString).get, sys.props.get(proxyPortConfigString).get.toInt)) - case _ => None - } - - val config = { - val cfg = RequestConfig.custom().setConnectTimeout(30) - if (proxy.isDefined) - cfg.setProxy(proxy.get).build() - else cfg.build() - } - - val request = new HttpPost(query) - val nvps = new JArrayList[BasicNameValuePair] - data.foreach { case (k, v) => nvps.add(new BasicNameValuePair(k, v)) } - request.setEntity(new UrlEncodedFormEntity(nvps)) - - for ((headerTag, headerData) <- headers) - request.addHeader(headerTag, headerData) - request.setConfig(config) - - var response: Option[CloseableHttpResponse] = None - try { - response = Some(httpClient.execute(request)) - val inputStream = response.get.getEntity().getContent - val result = scala.io.Source.fromInputStream(inputStream).getLines().mkString("\n") - logger.info(s"Response From UAA Server is $result") - result.parseJson.asJsObject().getFields("access_token") match { - case values => values(0).asInstanceOf[JsString].value - } - } - catch { - case ex: Throwable => - error(s"Error executing request ${ex.getMessage}") - // We need this exception to be thrown as this is a generic http request method and let caller handle. - throw ex - } - finally { - if (response.isDefined) - response.get.close() - } - } - finally { - httpClient.close() - } - } } diff --git a/src/main/scala/org/trustedanalytics/scoring/ScoringServiceJsonProtocol.scala b/src/main/scala/org/trustedanalytics/scoring/ScoringServiceJsonProtocol.scala index e7abf16..e12f1f8 100644 --- a/src/main/scala/org/trustedanalytics/scoring/ScoringServiceJsonProtocol.scala +++ b/src/main/scala/org/trustedanalytics/scoring/ScoringServiceJsonProtocol.scala @@ -24,7 +24,7 @@ import org.trustedanalytics.scoring.interfaces.{ Model, Field, ModelMetaData } import scala.reflect.ClassTag import scala.reflect.runtime.universe._ -class ScoringServiceJsonProtocol(model: Model) { +object ScoringServiceJsonProtocol { implicit object ModelMetaDataFormat extends JsonFormat[ModelMetaData] { override def write(obj: ModelMetaData): JsValue = { @@ -53,168 +53,5 @@ class ScoringServiceJsonProtocol(model: Model) { Field(name, value) } } - - implicit object DataInputFormat extends JsonFormat[Seq[Array[Any]]] { - - //don't need this method. just there to satisfy the API. - override def write(obj: Seq[Array[Any]]): JsValue = ??? - - override def read(json: JsValue): Seq[Array[Any]] = { - val records = json.asJsObject.getFields("records") match { - case Seq(JsArray(records)) => records - case x => deserializationError(s"Expected array of records but got $x") - } - decodeRecords(records) - } - } - - implicit object DataTypeJsonFormat extends JsonFormat[Any] { - override def write(obj: Any): JsValue = { - obj match { - case n: Int => new JsNumber(n) - case n: Long => new JsNumber(n) - case n: Float => new JsNumber(BigDecimal(n)) - case n: Double => new JsNumber(n) - case s: String => new JsString(s) - case s: Boolean => JsBoolean(s) - case dt: DateTime => JsString(org.joda.time.format.ISODateTimeFormat.dateTime.print(dt)) - case m: Map[_, _] @unchecked => mapToJson(m) - case v: List[_] => listToJson(v) - case v: Array[_] => listToJson(v.toList) - case v: Vector[_] => listToJson(v.toList) - case v: ArrayBuffer[_] @unchecked => listToJson(v.toList) - case n: java.lang.Long => new JsNumber(n.longValue()) - // case null => JsNull Consciously not writing nulls, may need to change, but for now it may catch bugs - case unk => - val name: String = if (unk != null) { - unk.getClass.getName - } - else { - "null" - } - serializationError("Cannot serialize " + name) - } - } - - override def read(json: JsValue): Any = { - json match { - case JsNumber(n) if n.isValidInt => n.intValue() - case JsNumber(n) if n.isValidLong => n.longValue() - case JsNumber(n) if n.isValidFloat => n.floatValue() - case JsNumber(n) => n.doubleValue() - case JsBoolean(b) => b - case JsString(s) => s - case JsArray(v) => v.map(x => read(x)) - case obj: JsObject => obj.fields.map { - case (a, JsArray(v)) => (a, v.map(x => read(x))) - case (a, JsNumber(b)) => (a, b) - } - case unk => deserializationError("Cannot deserialize " + unk.getClass.getName) - } - } - } - - implicit object DataOutputFormat extends JsonFormat[Array[Map[String, Any]]] { - - override def write(obj: Array[Map[String, Any]]): JsValue = { - val modelMetadata = model.modelMetadata() - JsObject("data" -> new JsArray(obj.map(output => DataTypeJsonFormat.write(output)).toList)) - } - - //don't need this method. just there to satisfy the API. - override def read(json: JsValue): Array[Map[String, Any]] = ??? - } - - def decodeRecords(records: List[JsValue]): Seq[Array[Any]] = { - val decodedRecords: Seq[Map[String, Any]] = records.map { record => - record match { - case JsObject(fields) => - val decodedRecord: Map[String, Any] = for ((feature, value) <- fields) yield (feature, decodeJValue(value)) - decodedRecord - } - } - var features: Seq[Array[Any]] = Seq[Array[Any]]() - decodedRecords.foreach(decodedRecord => { - val obsColumns = model.input() - val featureArray = new Array[Any](obsColumns.length) - if (decodedRecord.size != featureArray.length) { - throw new IllegalArgumentException( - "Size of input record is not equal to number of observation columns that model was trained on:\n" + - s"""Expected columns are: [${obsColumns.mkString(",")}]""" - ) - } - decodedRecord.foreach({ - case (name, value) => { - var counter = 0 - var found = false - while (counter < obsColumns.length && !found) { - if (obsColumns(counter).name != name) { - counter = counter + 1 - } - else { - featureArray(counter) = value - found = true - } - } - if (!found) { - throw new IllegalArgumentException( - s"""$name was not found in list of Observation Columns that model was trained on: [${obsColumns.mkString(",")}]""" - ) - } - - } - }) - features = features :+ featureArray - }) - features - } - - def decodeJValue(v: JsValue): Any = { - v match { - case JsString(s) => s - case JsNumber(n) => n.toDouble - case JsArray(items) => for (item <- items) yield decodeJValue(item) - case JsNull => null - case JsObject(fields) => - val decodedValue: Map[String, Any] = for ((feature, value) <- fields) yield (feature, decodeJValue(value)) - decodedValue - case x => deserializationError(s"Unexpected JSON type in record $x") - } - } - - private def mapToJson[K <: Any, V <: Any](m: Map[K, V]): JsObject = { - require(m != null, s"Scoring service cannot serialize null to JSON") - val jsMap: Map[String, JsValue] = m.map { - case (x) => x match { - case (k: String, n: Double) => (k, n.toJson) - case (k: String, n: Int) => (k, n.toJson) - case (k: String, n: Long) => (k, n.toJson) - case (k: String, n: Float) => (k, n.toJson) - case (k: String, str: String) => (k, JsString(str)) - case (k: String, list: List[_]) => (k, listToJson(list)) - case (k: String, array: Array[_]) => (k, listToJson(array.toList)) - case (k: String, vector: Vector[_]) => (k, listToJson(vector.toList)) - case unk => serializationError(s"Scoring service cannot serialize ${unk.getClass.getName} to JSON") - } - } - JsObject(jsMap) - } - - private def listToJson(list: List[Any]): JsArray = { - require(list != null, s"Scoring service cannot serialize null to JSON") - val jsElements = list.map { - case n: Double => n.toJson - case n: Int => n.toJson - case n: Long => n.toJson - case n: Float => n.toJson - case str: String => str.toJson - case map: Map[_, _] @unchecked => mapToJson(map) - case list: List[_] => listToJson(list) - case arr: Array[_] => listToJson(arr.toList) - case vector: Vector[_] => listToJson(vector.toList) - case unk => serializationError(s"Scoring service cannot serialize ${unk.getClass.getName} to Json") - } - new JsArray(jsElements) - } } diff --git a/src/test/scala/org/trustedanalytics/ScoringServiceJsonProtocolTest.scala b/src/test/scala/org/trustedanalytics/ScoringServiceJsonProtocolTest.scala index 395138d..25e45e5 100644 --- a/src/test/scala/org/trustedanalytics/ScoringServiceJsonProtocolTest.scala +++ b/src/test/scala/org/trustedanalytics/ScoringServiceJsonProtocolTest.scala @@ -15,10 +15,10 @@ */ package org.trustedanalytics -import org.trustedanalytics.scoring.ScoringServiceJsonProtocol -import org.scalatest.{ Matchers, WordSpec } +import org.trustedanalytics.scoring.{DataOutputFormatJsonProtocol, ScoringServiceJsonProtocol} +import org.scalatest.{Matchers, WordSpec} import spray.json._ -import org.trustedanalytics.scoring.interfaces.{ ModelMetaData, Field, Model } +import org.trustedanalytics.scoring.interfaces.{Field, Model, ModelMetaData} import scala.collection.immutable.Map @@ -39,7 +39,7 @@ class ScoringServiceJsonProtocolTest extends WordSpec with Matchers { override def score(row: Array[Any]): Array[Any] = ??? } - val jsonFormat = new ScoringServiceJsonProtocol(model) + val jsonFormat = new DataOutputFormatJsonProtocol(model) import jsonFormat.DataInputFormat import jsonFormat.DataOutputFormat