Skip to content

Commit

Permalink
DPNG-12495: revise model feature on 0.7.4 (#9)
Browse files Browse the repository at this point in the history
* revise model feature on 0.7.4

* cleanup
  • Loading branch information
jitendra42 authored and rodorad committed Jan 7, 2017
1 parent dc6a682 commit 2e28f1f
Show file tree
Hide file tree
Showing 6 changed files with 476 additions and 325 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package org.trustedanalytics.scoring

import org.joda.time.DateTime
import org.trustedanalytics.scoring.interfaces.Model
import spray.json.{JsString, _}

import scala.collection.immutable.Map
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe._
import spray.json.DefaultJsonProtocol._

/**
 * spray-json formats used to decode scoring requests and encode scoring results
 * for one model.
 *
 * @param model the model being scored; the columns returned by `model.input()`
 *              determine the position of each record field in the feature arrays
 */
class DataOutputFormatJsonProtocol(model: Model) {

  /** Reads a request of the form `{"records": [ {...}, ... ]}` into feature arrays. */
  implicit object DataInputFormat extends JsonFormat[Seq[Array[Any]]] {

    // Inputs are only ever read; this method exists solely to satisfy JsonFormat.
    override def write(obj: Seq[Array[Any]]): JsValue = ???

    override def read(json: JsValue): Seq[Array[Any]] = {
      val recordList = json.asJsObject.getFields("records") match {
        case Seq(JsArray(elements)) => elements
        case x => deserializationError(s"Expected array of records but got $x")
      }
      decodeRecords(recordList)
    }
  }

  /** Writes scoring results as `{"data": [ {...}, ... ]}`. */
  implicit object DataOutputFormat extends JsonFormat[Array[Map[String, Any]]] {

    override def write(obj: Array[Map[String, Any]]): JsValue =
      JsObject("data" -> new JsArray(obj.map(output => DataTypeJsonFormat.write(output)).toList))

    // Outputs are only ever written; this method exists solely to satisfy JsonFormat.
    override def read(json: JsValue): Array[Map[String, Any]] = ???
  }

  /**
   * Converts JSON records into feature arrays ordered like the model's input columns.
   *
   * @param records one JSON object per record, mapping column name to value
   * @return one array per record; slot `i` holds the value of the i-th column of `model.input()`
   * @throws DeserializationException if a record is not a JSON object or holds an unsupported value
   * @throws IllegalArgumentException if a record's field count differs from the number of input
   *                                  columns, or a field name is not an input column
   */
  def decodeRecords(records: List[JsValue]): Seq[Array[Any]] = {
    // Decode every record first so malformed JSON is reported before column validation.
    val decodedRecords: Seq[Map[String, Any]] = records.map {
      case JsObject(fields) =>
        val decodedRecord: Map[String, Any] = fields.map { case (feature, value) => (feature, decodeJValue(value)) }
        decodedRecord
      case other =>
        deserializationError(s"Expected each record to be a JSON object but got $other")
    }

    // The input schema is the same for every record, so fetch it once.
    val obsColumns = model.input()

    decodedRecords.map { decodedRecord =>
      if (decodedRecord.size != obsColumns.length) {
        throw new IllegalArgumentException(
          "Size of input record is not equal to number of observation columns that model was trained on:\n" +
            s"""Expected columns are: [${obsColumns.mkString(",")}]"""
        )
      }
      val featureArray = new Array[Any](obsColumns.length)
      decodedRecord.foreach {
        case (name, value) =>
          // First column with a matching name wins.
          val columnIndex = obsColumns.indexWhere(_.name == name)
          if (columnIndex < 0) {
            throw new IllegalArgumentException(
              s"""$name was not found in list of Observation Columns that model was trained on: [${obsColumns.mkString(",")}]"""
            )
          }
          featureArray(columnIndex) = value
      }
      featureArray
    }
  }

  /**
   * Converts a JSON value into a plain Scala value: strings stay strings, numbers become
   * `Double`, arrays become lists, objects become maps and JSON null becomes `null`.
   *
   * @throws DeserializationException for any other JSON type (for example booleans)
   */
  def decodeJValue(v: JsValue): Any = v match {
    case JsString(s) => s
    case JsNumber(n) => n.toDouble
    case JsArray(items) => items.map(decodeJValue)
    case JsNull => null
    case JsObject(fields) =>
      val decodedValue: Map[String, Any] = fields.map { case (feature, value) => (feature, decodeJValue(value)) }
      decodedValue
    case x => deserializationError(s"Unexpected JSON type in record $x")
  }

  /**
   * Serializes a map with `String` keys to a JSON object. Values are serialized by
   * [[DataTypeJsonFormat]], so maps, lists and top-level values accept the same types.
   *
   * @throws SerializationException if a key is not a `String` or a value is unsupported
   */
  private def mapToJson[K <: Any, V <: Any](m: Map[K, V]): JsObject = {
    require(m != null, s"Scoring service cannot serialize null to JSON")
    val jsMap: Map[String, JsValue] = m.map {
      case (k: String, v) => (k, DataTypeJsonFormat.write(v))
      case (k, _) =>
        val keyType = if (k == null) "null" else k.getClass.getName
        serializationError(s"Scoring service cannot serialize map key of type $keyType to JSON")
    }
    JsObject(jsMap)
  }

  /** Generic format for the loosely typed values that appear in scoring results. */
  implicit object DataTypeJsonFormat extends JsonFormat[Any] {

    /**
     * Serializes a supported value to JSON.
     *
     * @throws SerializationException for `null` or an unsupported type
     */
    override def write(obj: Any): JsValue = {
      obj match {
        // Nulls are deliberately rejected rather than written as JsNull, to surface bugs early.
        case null => serializationError("Cannot serialize null")
        // `n: Long` also matches boxed java.lang.Long, so no separate case is needed.
        case n: Int => JsNumber(n)
        case n: Long => JsNumber(n)
        // JsNumber.apply(Double) writes NaN/Infinity as JsNull; building a BigDecimal
        // directly would throw NumberFormatException for those values.
        case n: Float => JsNumber(n.toDouble)
        case n: Double => JsNumber(n)
        case s: String => JsString(s)
        case b: Boolean => JsBoolean(b)
        case dt: DateTime => JsString(org.joda.time.format.ISODateTimeFormat.dateTime.print(dt))
        case m: Map[_, _] @unchecked => mapToJson(m)
        case v: List[_] => listToJson(v)
        case v: Array[_] => listToJson(v.toList)
        case v: Vector[_] => listToJson(v.toList)
        case v: ArrayBuffer[_] @unchecked => listToJson(v.toList)
        case unk => serializationError("Cannot serialize " + unk.getClass.getName)
      }
    }

    /**
     * Deserializes JSON into the narrowest matching Scala value.
     *
     * @throws DeserializationException for JSON null at the top level
     */
    override def read(json: JsValue): Any = {
      json match {
        case JsNumber(n) if n.isValidInt => n.intValue()
        case JsNumber(n) if n.isValidLong => n.longValue()
        case JsNumber(n) if n.isValidFloat => n.floatValue()
        case JsNumber(n) => n.doubleValue()
        case JsBoolean(b) => b
        case JsString(s) => s
        case JsArray(v) => v.map(x => read(x))
        case obj: JsObject => obj.fields.map {
          case (a, JsArray(v)) => (a, v.map(x => read(x)))
          // Numeric fields are returned as the raw BigDecimal, kept for backward compatibility.
          case (a, JsNumber(b)) => (a, b)
          // Any other field type is read recursively instead of failing with a MatchError.
          case (a, other) => (a, read(other))
        }
        case unk => deserializationError("Cannot deserialize " + unk.getClass.getName)
      }
    }
  }

  /**
   * Serializes a list to a JSON array; elements are serialized by [[DataTypeJsonFormat]].
   *
   * @throws SerializationException if an element is `null` or of an unsupported type
   */
  private def listToJson(list: List[Any]): JsArray = {
    require(list != null, s"Scoring service cannot serialize null to JSON")
    new JsArray(list.map(element => DataTypeJsonFormat.write(element)))
  }

}
158 changes: 158 additions & 0 deletions src/main/scala/org/trustedanalytics/scoring/ScoringEngineHelper.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/**
* Copyright (c) 2015 Intel Corporation 
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.trustedanalytics.scoring

import java.io.{ FileOutputStream, File, FileInputStream }

import akka.actor.{ ActorSystem, Props }
import akka.io.IO
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{ FileSystem, Path }
import org.trustedanalytics.hadoop.config.client.oauth.TapOauthToken
import spray.can.Http
import akka.pattern.ask
import akka.util.Timeout
import scala.concurrent.duration._
import com.typesafe.config.{ Config, ConfigFactory }
import scala.reflect.ClassTag
import org.trustedanalytics.scoring.interfaces.Model
import org.trustedanalytics.hadoop.config.client.helper.Hdfs
import java.net.URI
import org.apache.commons.io.FileUtils

import org.apache.http.message.BasicNameValuePair
import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials }
import org.apache.http.impl.client.{ BasicCredentialsProvider, HttpClientBuilder }
import org.apache.http.client.methods.{ HttpPost, CloseableHttpResponse }
import org.apache.http.HttpHost
import org.apache.http.client.config.RequestConfig
import org.apache.http.client.entity.UrlEncodedFormEntity
import java.util.{ ArrayList => JArrayList }
import spray.json._
import org.trustedanalytics.model.archive.format.ModelArchiveFormat
import org.slf4j.LoggerFactory

/** Helpers for loading model archives (locally or from HDFS) and comparing models. */
object ScoringEngineHelper {
  import scala.util.control.NonFatal

  private val logger = LoggerFactory.getLogger(this.getClass)
  val config = ConfigFactory.load(this.getClass.getClassLoader)

  // HttpClient timeouts are expressed in milliseconds.
  private val ConnectTimeoutMillis: Int = 30 * 1000

  /**
   *
   * @param model Original model
   * @param revisedModel Revised model
   * @return true if model compatible i.e model type is same and model input/output parameters are same
   *         else returns false
   */
  def isModelCompatible(model: Model, revisedModel: Model): Boolean = {
    model.modelMetadata().modelType == revisedModel.modelMetadata().modelType &&
      model.input().deep == revisedModel.input().deep &&
      model.output().deep == revisedModel.output().deep
  }

  /**
   * Loads a model from a model archive (.mar) file.
   *
   * Paths starting with `hdfs:/` are copied to a local temporary file first; the temporary
   * file is deleted before this method returns, whether or not loading succeeded.
   *
   * @param modelFilePath local path or `hdfs:/...` / `hdfs://...` URI of the archive
   * @return the model read from the archive
   */
  def getModel(modelFilePath: String): Model = {
    if (modelFilePath.startsWith("hdfs:/")) {
      // Normalize the single-slash form "hdfs:/x/y" to "hdfs://x/y".
      val hdfsPath =
        if (modelFilePath.startsWith("hdfs://")) modelFilePath
        else "hdfs://" + modelFilePath.stripPrefix("hdfs:/")

      val hdfsFileSystem = try {
        val token = new TapOauthToken(getJwtToken())
        logger.info(s"Successfully retreived a token for user ${token.getUserName}")
        Hdfs.newInstance().createFileSystem(token)
      }
      catch {
        // Best effort: fall back to the plain Hadoop client when the token-based setup fails.
        case NonFatal(t) =>
          logger.warn("Failed to create HDFS instance using hadoop-library. Default to FileSystem", t)
          FileSystem.get(new URI(hdfsPath), new Configuration())
      }

      val tempMarFile = File.createTempFile("model", ".mar")
      try {
        hdfsFileSystem.copyToLocalFile(false, new Path(hdfsPath), new Path(tempMarFile.getAbsolutePath))
        readModelArchive(tempMarFile)
      }
      finally {
        // Deleting here makes a JVM shutdown hook unnecessary; hooks would pile up,
        // one per call, each time a model is (re)loaded.
        FileUtils.deleteQuietly(tempMarFile)
      }
    }
    else {
      readModelArchive(new File(modelFilePath))
    }
  }

  /** Reads a model from a local archive file using this object's class loader. */
  private def readModelArchive(marFile: File): Model = {
    logger.info("calling ModelArchiveFormat to get the model")
    ModelArchiveFormat.read(marFile, this.getClass.getClassLoader, None)
  }

  /**
   * Requests an OAuth access token from the UAA server with the password grant.
   *
   * Reads `UAA_URI`, `FS_TECHNICAL_USER_NAME`, `FS_TECHNICAL_USER_PASSWORD`, `UAA_CLIENT_NAME`
   * and `UAA_CLIENT_PASSWORD` from the environment, and honours the `https.proxyHost` /
   * `https.proxyPort` system properties when both are set.
   *
   * @return the value of the `access_token` field of the UAA response
   * @throws DeserializationException if the response has no body or no string `access_token`
   */
  def getJwtToken(): String = {

    val query = s"http://${System.getenv("UAA_URI")}/oauth/token"
    val headers = List(("Accept", "application/json"))
    val data = List(("username", System.getenv("FS_TECHNICAL_USER_NAME")), ("password", System.getenv("FS_TECHNICAL_USER_PASSWORD")), ("grant_type", "password"))
    val credentialsProvider = new BasicCredentialsProvider()
    credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(System.getenv("UAA_CLIENT_NAME"), System.getenv("UAA_CLIENT_PASSWORD")))

    // TODO: This method uses Apache HttpComponents HttpClient as spray-http library does not support proxy over https
    val (proxyHostConfigString, proxyPortConfigString) = ("https.proxyHost", "https.proxyPort")
    val httpClient = HttpClientBuilder.create().setDefaultCredentialsProvider(credentialsProvider).build()
    try {
      // A proxy is used only when both host and port are configured.
      val proxy = for {
        host <- sys.props.get(proxyHostConfigString)
        port <- sys.props.get(proxyPortConfigString)
      } yield new HttpHost(host, port.toInt)

      val requestConfigBuilder = RequestConfig.custom().setConnectTimeout(ConnectTimeoutMillis)
      proxy.foreach(p => requestConfigBuilder.setProxy(p))
      val requestConfig = requestConfigBuilder.build()

      val request = new HttpPost(query)
      val nvps = new JArrayList[BasicNameValuePair]
      data.foreach { case (k, v) => nvps.add(new BasicNameValuePair(k, v)) }
      request.setEntity(new UrlEncodedFormEntity(nvps))
      headers.foreach { case (headerTag, headerData) => request.addHeader(headerTag, headerData) }
      request.setConfig(requestConfig)

      try {
        val response = httpClient.execute(request)
        try {
          val statusLine = response.getStatusLine
          val entity = Option(response.getEntity).getOrElse(
            deserializationError(s"UAA server returned no response body ($statusLine)"))
          val result = scala.io.Source.fromInputStream(entity.getContent, "UTF-8").getLines().mkString("\n")
          // Log only the status line: the body carries the access token and must not reach the logs.
          logger.info(s"Received response from UAA server: $statusLine")
          result.parseJson.asJsObject().getFields("access_token") match {
            case Seq(JsString(token)) => token
            case _ => deserializationError(s"UAA response does not contain a string access_token ($statusLine)")
          }
        }
        finally {
          response.close()
        }
      }
      catch {
        case NonFatal(ex) =>
          logger.error(s"Error executing request ${ex.getMessage}", ex)
          // Rethrow the original exception so the caller decides how to handle it.
          throw ex
      }
    }
    finally {
      httpClient.close()
    }
  }

}
Loading

0 comments on commit 2e28f1f

Please sign in to comment.