-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
DPNG-12495: revise model feature on 0.7.4 (#9)
* revise model feature on 0.7.4 * cleanup
- Loading branch information
1 parent
dc6a682
commit 2e28f1f
Showing
6 changed files
with
476 additions
and
325 deletions.
There are no files selected for viewing
179 changes: 179 additions & 0 deletions
179
src/main/scala/org/trustedanalytics/scoring/DataOutputFormatJsonProtocol.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
package org.trustedanalytics.scoring | ||
|
||
import org.joda.time.DateTime | ||
import org.trustedanalytics.scoring.interfaces.Model | ||
import spray.json.{JsString, _} | ||
|
||
import scala.collection.immutable.Map | ||
import scala.collection.mutable.ArrayBuffer | ||
import scala.reflect.runtime.universe._ | ||
import spray.json.DefaultJsonProtocol._ | ||
|
||
/**
 * spray-json formats used to decode scoring requests and encode scoring results.
 *
 * Incoming records are JSON objects keyed by column name; they are re-ordered into
 * positional arrays that follow the column order reported by `model.input()`.
 *
 * @param model model whose `input()` columns define which record fields are accepted
 *              and the position each field takes in the decoded array
 */
class DataOutputFormatJsonProtocol(model: Model) {

  /** Reads a request body of the form `{"records": [ {...}, ... ]}` into one array per record. */
  implicit object DataInputFormat extends JsonFormat[Seq[Array[Any]]] {

    //don't need this method. just there to satisfy the API.
    override def write(obj: Seq[Array[Any]]): JsValue = ???

    /**
     * @param json request body; must be a JSON object with a single `records` array
     * @return one positional feature array per record (see `decodeRecords`)
     */
    override def read(json: JsValue): Seq[Array[Any]] = {
      val records = json.asJsObject.getFields("records") match {
        case Seq(JsArray(records)) => records
        case x => deserializationError(s"Expected array of records but got $x")
      }
      decodeRecords(records)
    }
  }

  /** Writes scoring results as `{"data": [ {...}, ... ]}`. */
  implicit object DataOutputFormat extends JsonFormat[Array[Map[String, Any]]] {

    /**
     * @param obj one map per scored record
     * @return JSON object whose `data` field holds each map serialized by `DataTypeJsonFormat`
     */
    override def write(obj: Array[Map[String, Any]]): JsValue = {
      JsObject("data" -> new JsArray(obj.map(output => DataTypeJsonFormat.write(output)).toList))
    }

    //don't need this method. just there to satisfy the API.
    override def read(json: JsValue): Array[Map[String, Any]] = ???
  }

  /**
   * Decodes JSON records into positional arrays ordered like `model.input()`.
   *
   * @param records JSON values, each expected to be an object keyed by column name
   * @return one array per record, with every value stored at the index of its column
   * @throws DeserializationException if a record is not a JSON object or holds an unsupported JSON type
   * @throws IllegalArgumentException if a record's field count differs from the number of model
   *                                  columns, or a field name matches no model column
   */
  def decodeRecords(records: List[JsValue]): Seq[Array[Any]] = {
    // All records are decoded before any column validation, so malformed JSON is reported first.
    val decodedRecords: Seq[Map[String, Any]] = records.map {
      case JsObject(fields) =>
        val decodedRecord: Map[String, Any] = for ((feature, value) <- fields) yield (feature, decodeJValue(value))
        decodedRecord
      case other =>
        // Previously an unmatched case surfaced as a raw scala.MatchError.
        deserializationError(s"Expected each record to be a JSON object but got $other")
    }

    // The column list is the same for every record, so fetch it once.
    val obsColumns = model.input()

    decodedRecords.map { decodedRecord =>
      val featureArray = new Array[Any](obsColumns.length)
      if (decodedRecord.size != featureArray.length) {
        throw new IllegalArgumentException(
          "Size of input record is not equal to number of observation columns that model was trained on:\n" +
            s"""Expected columns are: [${obsColumns.mkString(",")}]"""
        )
      }
      decodedRecord.foreach {
        case (name, value) =>
          val position = obsColumns.indexWhere(_.name == name)
          if (position < 0) {
            throw new IllegalArgumentException(
              s"""$name was not found in list of Observation Columns that model was trained on: [${obsColumns.mkString(",")}]"""
            )
          }
          featureArray(position) = value
      }
      featureArray
    }
  }

  /**
   * Converts a JSON value into a plain Scala value.
   *
   * Strings become `String`, numbers become `Double`, arrays become a list of decoded items,
   * objects become a `Map[String, Any]` and JSON null becomes `null`.
   *
   * @throws DeserializationException for any other JSON type
   */
  def decodeJValue(v: JsValue): Any = {
    v match {
      case JsString(s) => s
      case JsNumber(n) => n.toDouble
      case JsArray(items) => for (item <- items) yield decodeJValue(item)
      case JsNull => null
      case JsObject(fields) =>
        val decodedValue: Map[String, Any] = for ((feature, value) <- fields) yield (feature, decodeJValue(value))
        decodedValue
      case x => deserializationError(s"Unexpected JSON type in record $x")
    }
  }

  /**
   * Serializes a map with `String` keys to a JSON object.
   *
   * Supported value types are `Double`, `Int`, `Long`, `Float`, `String`, `List`, `Array` and `Vector`;
   * any other entry (including a non-`String` key) is rejected with a serialization error.
   */
  private def mapToJson[K <: Any, V <: Any](m: Map[K, V]): JsObject = {
    require(m != null, s"Scoring service cannot serialize null to JSON")
    val jsMap: Map[String, JsValue] = m.map {
      case (k: String, n: Double) => (k, n.toJson)
      case (k: String, n: Int) => (k, n.toJson)
      case (k: String, n: Long) => (k, n.toJson)
      case (k: String, n: Float) => (k, n.toJson)
      case (k: String, str: String) => (k, JsString(str))
      case (k: String, list: List[_]) => (k, listToJson(list))
      case (k: String, array: Array[_]) => (k, listToJson(array.toList))
      case (k: String, vector: Vector[_]) => (k, listToJson(vector.toList))
      case unk => serializationError(s"Scoring service cannot serialize ${unk.getClass.getName} to JSON")
    }
    JsObject(jsMap)
  }

  /** Generic format for the loosely typed values found in scoring results. */
  implicit object DataTypeJsonFormat extends JsonFormat[Any] {

    /**
     * Serializes a supported primitive, string, date, map or sequence.
     *
     * @throws SerializationException for `null` or any unsupported type
     */
    override def write(obj: Any): JsValue = {
      obj match {
        case n: Int => new JsNumber(n)
        case n: Long => new JsNumber(n)
        case n: Float => new JsNumber(BigDecimal(n))
        case n: Double => new JsNumber(n)
        case s: String => new JsString(s)
        case s: Boolean => JsBoolean(s)
        case dt: DateTime => JsString(org.joda.time.format.ISODateTimeFormat.dateTime.print(dt))
        case m: Map[_, _] @unchecked => mapToJson(m)
        case v: List[_] => listToJson(v)
        case v: Array[_] => listToJson(v.toList)
        case v: Vector[_] => listToJson(v.toList)
        case v: ArrayBuffer[_] @unchecked => listToJson(v.toList)
        case n: java.lang.Long => new JsNumber(n.longValue())
        // case null => JsNull  Consciously not writing nulls, may need to change, but for now it may catch bugs
        case unk =>
          val name: String = if (unk != null) {
            unk.getClass.getName
          }
          else {
            "null"
          }
          serializationError("Cannot serialize " + name)
      }
    }

    /**
     * Deserializes a JSON value, choosing the narrowest numeric type that can hold a number.
     *
     * Inside a JSON object, array fields are read element-wise and number fields are kept as
     * `BigDecimal`; every other field type is read recursively with this method.
     *
     * @throws DeserializationException for JSON types that are not supported (for example null)
     */
    override def read(json: JsValue): Any = {
      json match {
        case JsNumber(n) if n.isValidInt => n.intValue()
        case JsNumber(n) if n.isValidLong => n.longValue()
        case JsNumber(n) if n.isValidFloat => n.floatValue()
        case JsNumber(n) => n.doubleValue()
        case JsBoolean(b) => b
        case JsString(s) => s
        case JsArray(v) => v.map(x => read(x))
        case obj: JsObject => obj.fields.map {
          case (a, JsArray(v)) => (a, v.map(x => read(x)))
          case (a, JsNumber(b)) => (a, b)
          // Strings, booleans and nested objects used to fall through to a scala.MatchError.
          case (a, other) => (a, read(other))
        }
        case unk => deserializationError("Cannot deserialize " + unk.getClass.getName)
      }
    }
  }

  /**
   * Serializes a list to a JSON array.
   *
   * Supported element types are `Double`, `Int`, `Long`, `Float`, `String`, `Map`, `List`,
   * `Array` and `Vector`; any other element is rejected with a serialization error.
   */
  private def listToJson(list: List[Any]): JsArray = {
    require(list != null, s"Scoring service cannot serialize null to JSON")
    val jsElements = list.map {
      case n: Double => n.toJson
      case n: Int => n.toJson
      case n: Long => n.toJson
      case n: Float => n.toJson
      case str: String => str.toJson
      case map: Map[_, _] @unchecked => mapToJson(map)
      case nested: List[_] => listToJson(nested)
      case arr: Array[_] => listToJson(arr.toList)
      case vector: Vector[_] => listToJson(vector.toList)
      case unk => serializationError(s"Scoring service cannot serialize ${unk.getClass.getName} to Json")
    }
    new JsArray(jsElements)
  }

}
158 changes: 158 additions & 0 deletions
158
src/main/scala/org/trustedanalytics/scoring/ScoringEngineHelper.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/** | ||
* Copyright (c) 2015 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.trustedanalytics.scoring | ||
|
||
import java.io.{ FileOutputStream, File, FileInputStream } | ||
|
||
import akka.actor.{ ActorSystem, Props } | ||
import akka.io.IO | ||
import org.apache.hadoop.conf.Configuration | ||
import org.apache.hadoop.fs.{ FileSystem, Path } | ||
import org.trustedanalytics.hadoop.config.client.oauth.TapOauthToken | ||
import spray.can.Http | ||
import akka.pattern.ask | ||
import akka.util.Timeout | ||
import scala.concurrent.duration._ | ||
import com.typesafe.config.{ Config, ConfigFactory } | ||
import scala.reflect.ClassTag | ||
import org.trustedanalytics.scoring.interfaces.Model | ||
import org.trustedanalytics.hadoop.config.client.helper.Hdfs | ||
import java.net.URI | ||
import org.apache.commons.io.FileUtils | ||
|
||
import org.apache.http.message.BasicNameValuePair | ||
import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials } | ||
import org.apache.http.impl.client.{ BasicCredentialsProvider, HttpClientBuilder } | ||
import org.apache.http.client.methods.{ HttpPost, CloseableHttpResponse } | ||
import org.apache.http.HttpHost | ||
import org.apache.http.client.config.RequestConfig | ||
import org.apache.http.client.entity.UrlEncodedFormEntity | ||
import java.util.{ ArrayList => JArrayList } | ||
import spray.json._ | ||
import org.trustedanalytics.model.archive.format.ModelArchiveFormat | ||
import org.slf4j.LoggerFactory | ||
|
||
/** Helpers for loading model archives and checking whether a revised model can replace the current one. */
object ScoringEngineHelper {
  private val logger = LoggerFactory.getLogger(this.getClass)
  val config = ConfigFactory.load(this.getClass.getClassLoader)

  /**
   *
   * @param model        Original model
   * @param revisedModel Revised model
   * @return true if model compatible i.e model type is same and model input/output parameters are same
   *         else returns false
   */
  def isModelCompatible(model: Model, revisedModel: Model): Boolean = {
    model.modelMetadata().modelType == revisedModel.modelMetadata().modelType &&
      model.input().deep == revisedModel.input().deep &&
      model.output().deep == revisedModel.output().deep
  }

  /**
   * Loads a model from a model archive (.mar) file.
   *
   * Paths starting with `hdfs:/` are first copied to a local temporary file, which is deleted
   * again once the archive has been read; any other path is read directly as a local file.
   *
   * @param modelFilePath local path or `hdfs:/` URI of the model archive
   * @return the model read by `ModelArchiveFormat`
   */
  def getModel(modelFilePath: String): Model = {
    var tempMarFile: File = null
    var marFilePath = modelFilePath
    try {
      if (marFilePath.startsWith("hdfs:/")) {
        if (!marFilePath.startsWith("hdfs://")) {
          // Rewrite the single-slash form "hdfs:/x" as "hdfs://x".
          val relativePath = marFilePath.substring(marFilePath.indexOf("hdfs:") + 6)
          marFilePath = "hdfs://" + relativePath
        }

        val hdfsFileSystem = try {
          val token = new TapOauthToken(getJwtToken())
          logger.info(s"Successfully retreived a token for user ${token.getUserName}")
          Hdfs.newInstance().createFileSystem(token)
        }
        catch {
          // Deliberately broad: any failure of the token-based path falls back to the default FileSystem.
          case t: Throwable =>
            logger.info("Failed to create HDFS instance using hadoop-library. Default to FileSystem", t)
            org.apache.hadoop.fs.FileSystem.get(new URI(marFilePath), new Configuration())
        }
        tempMarFile = File.createTempFile("model", ".mar")
        hdfsFileSystem.copyToLocalFile(false, new Path(marFilePath), new Path(tempMarFile.getAbsolutePath))
        marFilePath = tempMarFile.getAbsolutePath
      }
      logger.info("calling ModelArchiveFormat to get the model")
      // The finally block below is the normal cleanup path. The shutdown hook only covers a JVM exit
      // while the archive is being read, so it is registered only when a temp file actually exists.
      if (tempMarFile != null) {
        val fileToDelete = tempMarFile
        sys.addShutdownHook(FileUtils.deleteQuietly(fileToDelete))
      }
      ModelArchiveFormat.read(new File(marFilePath), this.getClass.getClassLoader, None)
    }
    finally {
      FileUtils.deleteQuietly(tempMarFile)
    }
  }

  /**
   * Requests an OAuth access token from the UAA server using the password grant.
   *
   * Reads `UAA_URI`, `FS_TECHNICAL_USER_NAME`, `FS_TECHNICAL_USER_PASSWORD`, `UAA_CLIENT_NAME` and
   * `UAA_CLIENT_PASSWORD` from the environment. The request goes through an HTTP proxy when both the
   * `https.proxyHost` and `https.proxyPort` system properties are set.
   *
   * @return the value of the `access_token` field of the UAA response
   * @throws DeserializationException if the response has no string `access_token` field
   */
  def getJwtToken(): String = {

    val query = s"http://${System.getenv("UAA_URI")}/oauth/token"
    val headers = List(("Accept", "application/json"))
    val data = List(("username", System.getenv("FS_TECHNICAL_USER_NAME")), ("password", System.getenv("FS_TECHNICAL_USER_PASSWORD")), ("grant_type", "password"))
    val credentialsProvider = new BasicCredentialsProvider()
    credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(System.getenv("UAA_CLIENT_NAME"), System.getenv("UAA_CLIENT_PASSWORD")))

    // HttpClient timeouts are expressed in milliseconds; the previous literal 30 meant 30 ms.
    val connectTimeoutMillis = 30 * 1000

    // TODO: This method uses Apache HttpComponents HttpClient as spray-http library does not support proxy over https
    val (proxyHostConfigString, proxyPortConfigString) = ("https.proxyHost", "https.proxyPort")
    val httpClient = HttpClientBuilder.create().setDefaultCredentialsProvider(credentialsProvider).build()
    try {
      // A proxy is used only when both the host and the port properties are present.
      val proxy = for {
        host <- sys.props.get(proxyHostConfigString)
        port <- sys.props.get(proxyPortConfigString)
      } yield new HttpHost(host, port.toInt)

      val requestConfig = {
        val builder = RequestConfig.custom().setConnectTimeout(connectTimeoutMillis)
        proxy.foreach(proxyHost => builder.setProxy(proxyHost))
        builder.build()
      }

      val request = new HttpPost(query)
      val nvps = new JArrayList[BasicNameValuePair]
      data.foreach { case (k, v) => nvps.add(new BasicNameValuePair(k, v)) }
      request.setEntity(new UrlEncodedFormEntity(nvps))

      for ((headerTag, headerData) <- headers)
        request.addHeader(headerTag, headerData)
      request.setConfig(requestConfig)

      try {
        val response: CloseableHttpResponse = httpClient.execute(request)
        try {
          val inputStream = response.getEntity().getContent
          val result = scala.io.Source.fromInputStream(inputStream).getLines().mkString("\n")
          // Log only the status line: the response body carries the access token.
          logger.info(s"Response status from UAA Server is ${response.getStatusLine}")
          result.parseJson.asJsObject().getFields("access_token") match {
            case Seq(JsString(token)) => token
            case _ => deserializationError("UAA token response did not contain a string 'access_token' field")
          }
        }
        finally {
          response.close()
        }
      }
      catch {
        case ex: Throwable =>
          // Log through the logger; the bare error(...) call threw its own exception and hid `ex`.
          logger.error(s"Error executing request ${ex.getMessage}", ex)
          // We need this exception to be thrown as this is a generic http request method and let caller handle.
          throw ex
      }
    }
    finally {
      httpClient.close()
    }
  }

}
Oops, something went wrong.