-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
84 lines (70 loc) · 2.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import joblib
import pandas as pd
from fastapi import FastAPI, Body
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, NonNegativeInt
from format import Workclass, Education, Marital, Occupation, Sex, Relationship, Race, Country
from model import inference
# Create the ASGI application. The title, description and version below are
# published in the generated OpenAPI schema and the interactive docs page.
app = FastAPI(
    version="1.0.0",
    title="Census Income API",
    description=(
        "An API that classifies individuals into high salary (>50k) and "
        "low salary (<=50k) from Census Bureau data."
    ),
)
# Helper for turning snake_case field names into hyphenated alias names.
def hyphenate(string: str) -> str:
    """Return *string* with every underscore replaced by a hyphen.

    Example: ``"education_num"`` -> ``"education-num"``.
    """
    pieces = string.split("_")
    return "-".join(pieces)
# Request schema for the inference endpoint, with type hints and validation.
class Input(BaseModel):
    """One Census record submitted to ``POST /predict``.

    Multi-word fields are Python identifiers internally but are exposed under
    hyphenated aliases (e.g. ``education-num``). Because ``populate_by_name``
    is not enabled, clients must send the hyphenated keys.
    """

    # Configuration is declared via ``model_config`` (pydantic v2); the nested
    # ``class Config`` form is deprecated. The example payload previously sat
    # in a bare ``Body(...)`` call whose result was discarded, so it never
    # reached the schema.
    model_config = ConfigDict(
        # Store enum fields as their ``.value`` rather than the member, so
        # ``model_dump()`` yields plain values for the DataFrame
        # (assumes the ``format`` types are Enums — confirm).
        use_enum_values=True,
        # Surfaces a sample request body in the OpenAPI schema / docs UI.
        json_schema_extra={
            "examples": [
                {
                    "age": 38,
                    "workclass": "State-gov",
                    "fnlgt": 77516,
                    "education": "Bachelors",
                    "education-num": 13,
                    "marital-status": "Never-married",
                    "occupation": "Adm-clerical",
                    "relationship": "Not-in-family",
                    "race": "White",
                    "sex": "Male",
                    "capital-gain": 2174,
                    "capital-loss": 0,
                    "hours-per-week": 40,
                    "native-country": "United-States",
                }
            ]
        },
    )

    age: PositiveInt
    workclass: Workclass
    fnlgt: PositiveInt
    education: Education
    education_num: PositiveInt = Field(alias="education-num")
    marital_status: Marital = Field(alias="marital-status")
    occupation: Occupation
    relationship: Relationship
    race: Race
    sex: Sex
    capital_gain: NonNegativeInt = Field(alias="capital-gain")
    capital_loss: NonNegativeInt = Field(alias="capital-loss")
    hours_per_week: NonNegativeInt = Field(alias="hours-per-week")
    native_country: Country = Field(alias="native-country")
# Load the serialized classifier once, when the server starts.
# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
# favour of lifespan handlers; migrating requires changing the ``FastAPI(...)``
# construction as well.
@app.on_event("startup")
async def load_model():
    """Deserialize the trained model into the module-level ``model`` name.

    The path is relative to the process's current working directory, so the
    server must be started from the directory that contains ``model/``.
    """
    global model
    artifact_path = "model/rfc_model.pkl"
    # joblib.load unpickles the file; only point it at trusted artifacts.
    model = joblib.load(artifact_path)
# Root GET endpoint.
@app.get("/")
async def say_hello():
    """Respond to ``GET /`` with a fixed greeting payload."""
    payload = {"greeting": "Hello World!"}
    return payload
# Inference POST endpoint.
@app.post("/predict")
async def predict_from_data(data: Input):
    """Classify one Census record and return the model's prediction.

    Args:
        data: The validated request body.

    Returns:
        ``{"prediction": ...}`` wrapping whatever ``inference`` returns for
        the single-row DataFrame built from ``data``.
    """
    # ``model_dump(by_alias=True)`` emits the hyphenated column names
    # (``education-num``, ``marital-status``, ...) directly. This replaces the
    # deprecated pydantic-v1 ``.dict()`` call and the blanket "_" -> "-"
    # column rename that used to follow it; the resulting columns are the same.
    # NOTE(review): presumably these names match the columns the model was
    # trained on — confirm against the ``model`` module.
    input_data = pd.DataFrame([data.model_dump(by_alias=True)])
    # NOTE(review): ``inference`` is called synchronously inside this
    # coroutine, so it blocks the event loop while it runs; a plain ``def``
    # handler would let FastAPI run it in its threadpool instead.
    return {"prediction": inference(model, input_data)}