-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
190 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/usr/bin/env Rscript | ||
|
||
install.packages(c("dplyr","data.table","tidyr"), repos="https://cran.r-project.org") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
FROM r-base:4.3.3 | ||
|
||
|
||
COPY data /data | ||
COPY *.R / | ||
COPY *.rds / | ||
|
||
RUN Rscript packages.R | ||
|
||
ENTRYPOINT ["Rscript", "run.R"] | ||
CMD ["predict", "/data/fake_data.csv"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
#!/usr/bin/env Rscript | ||
|
||
library(dplyr) | ||
library(tidyr) | ||
|
||
|
||
source("submission.R") | ||
|
||
print_usage <- function() { | ||
cat("Usage:\n") | ||
cat(" Rscript script.R predict INPUT_FILE [--output OUTPUT_FILE]\n") | ||
cat(" Rscript script.R score --prediction PREDICTION_FILE --ground_truth GROUND_TRUTH_FILE [--output OUTPUT_FILE]\n") | ||
} | ||
|
||
parse_arguments <- function() { | ||
args <- list() | ||
command_args <- commandArgs(trailingOnly = TRUE) | ||
if (length(command_args) > 0) { | ||
args$command <- command_args[1] | ||
|
||
if (is.null(args$command)) { | ||
stop("Error: No command provided.") | ||
} | ||
|
||
if (args$command == "predict") { | ||
args$input <- commandArgs(trailingOnly = TRUE)[2] | ||
args$output <- get_argument("--output") | ||
} else if (args$command == "score") { | ||
args$prediction <- get_argument("--prediction") | ||
args$ground_truth <- get_argument("--ground_truth") | ||
args$output <- get_argument("--output") | ||
} | ||
} else { | ||
stop("Error: No command provided. Run the script with predict or score.") | ||
} | ||
|
||
return(args) | ||
} | ||
|
||
get_argument <- function(arg_name) { | ||
if (arg_name %in% commandArgs(trailingOnly = TRUE)) { | ||
arg_index <- which(commandArgs(trailingOnly = TRUE) == arg_name) | ||
if (arg_index < length(commandArgs(trailingOnly = TRUE))) { | ||
return(commandArgs(trailingOnly = TRUE)[arg_index + 1]) | ||
} | ||
} | ||
return(NULL) | ||
} | ||
|
||
parse_and_run_predict <- function(args) { | ||
if (is.null(args$input)) { | ||
stop("Error: Please provide --input argument for prediction.") | ||
} | ||
|
||
cat("Processing input data for prediction from:", args$input, "\n") | ||
if (!is.null(args$output)) { | ||
cat("Output will be saved to:", args$output, "\n") | ||
} | ||
run_predict(args$input, args$output) | ||
} | ||
|
||
run_score <- function(args) { | ||
if (is.null(args$prediction) || is.null(args$ground_truth)) { | ||
stop("Error: Please provide --prediction and --ground_truth arguments for scoring.") | ||
} | ||
|
||
cat("Scoring predictions from:", args$prediction, "\n") | ||
cat("Ground truth data from:", args$ground_truth, "\n") | ||
if (!is.null(args$output)) { | ||
cat("Evaluation score will be saved to:", args$output, "\n") | ||
} | ||
# Call your submission function for scoring here | ||
} | ||
|
||
run_predict <- function(input_path, output=NULL) { | ||
if (is.null(output)) { | ||
output <- stdout() | ||
} | ||
|
||
|
||
# Read data from input file | ||
df <- read.csv(input_path, encoding="latin1") | ||
|
||
# Clean the data | ||
df <- clean_df(df) # Assuming clean_df is a function in the submission package | ||
|
||
# Make predictions | ||
predictions <- predict_outcomes(df) # Assuming predict_outcomes is a function in the submission package | ||
|
||
# Check if predictions have the required format | ||
stopifnot(ncol(predictions) == 2, | ||
all(c("nomem_encr", "prediction") %in% colnames(predictions))) | ||
|
||
# Write predictions to output file | ||
write.csv(predictions, output, row.names = FALSE) | ||
} | ||
|
||
|
||
# Main function | ||
main <- function() { | ||
args <- parse_arguments() | ||
|
||
if (args$command == "predict") { | ||
parse_and_run_predict(args) | ||
} else if (args$command == "score") { | ||
run_score(args) | ||
} else { | ||
stop("Error: Invalid command. Use 'predict' or 'score'.") | ||
} | ||
} | ||
|
||
# Call main function | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"dockerfile": "python.Dockerfile"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# edit the preprocessing function using the code you used for preprocesing the train data | ||
clean_df <- function(df){ | ||
# Process the input data to feed the model | ||
|
||
## Selecting variables | ||
keepcols = c('nomem_encr', 'birthyear_bg', 'gender_bg', 'burgstat_2020','oplmet_2020', 'cf20m454') | ||
|
||
df <- df %>% select(all_of(keepcols)) | ||
|
||
# imputing missing values with mode (for factors) or median (for interval variables) | ||
my_mode <- function(x) { | ||
x <-x[!is.na(x)] | ||
ux <- unique(x) | ||
tab <- tabulate(match(x, ux)) | ||
mode <- ux[tab == max(tab)] | ||
ifelse(length(mode) > 1, sample(mode, 1), mode) | ||
} | ||
|
||
df <- df %>% | ||
mutate(across(c(gender_bg, burgstat_2020, oplmet_2020, cf20m454), ~replace_na(., my_mode(.))), | ||
across(c(gender_bg, burgstat_2020, oplmet_2020, cf20m454), as.factor), | ||
across(birthyear_bg, ~replace_na(., median(., na.rm=TRUE)))) | ||
|
||
return(df) | ||
} | ||
|
||
|
||
|
||
# if necessary, edit the function so it returns predicted classes (1/0), not probabilities | ||
predict_outcomes <- function(df, model_path="./model.rds"){ | ||
# preprocess the holdout data | ||
df <- clean_df(df) | ||
ids <- select(df, nomem_encr) | ||
|
||
# Load the model | ||
model <- readRDS(model_path) | ||
|
||
# !if necessary, make edits to produce predicted classes | ||
# E.g. if you used glm() function to train a model, add 'type="response"' to get probabilities | ||
pred <- predict(model, df, type="response") | ||
#and then transform them into predicted classes | ||
pred <- ifelse(pred>0.5, 1, 0) | ||
|
||
# adding prediction column to id column | ||
ids$prediction<- pred | ||
|
||
return(ids) | ||
} | ||
|
||
|
||
# ######## do not edit this ############################ | ||
# df <- read.csv(args[1]) | ||
# predictions <- predict_holdout(df) | ||
# write.csv(predictions,"predictions.csv", row.names = FALSE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters