Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
qingzma committed Oct 19, 2023
1 parent 17fe681 commit f00892d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 25 deletions.
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion dbestclient/cli/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def default(self, inp):
# go to the backend server
else:
# sqlExecutor = SqlExecutor(config)
# print(self.query)
print(self.query)
self.query = self.query.lower()
# self.query.replace(";",'')
self.sqlExecutor.execute(self.query)

Expand Down
62 changes: 40 additions & 22 deletions dbestclient/executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def init_model_catalog(self):
print("start loading pre-existing models.")

with open(
self.config.get_config()["warehousedir"] + "/" + file_name, "rb"
self.config.get_config()[
"warehousedir"] + "/" + file_name, "rb"
) as f:
model = dill.load(f)
self.model_catalog.model_catalog[
Expand Down Expand Up @@ -111,7 +112,7 @@ def execute(self, sql):
# prepare the parser
if type(sql) == str:
self.parser = DBEstParser()
self.parser.parse(sql)
self.parser.parse(sql.lower())
elif type(sql) == DBEstParser:
self.parser = sql
else:
Expand Down Expand Up @@ -228,7 +229,7 @@ def execute(self, sql):
split_char=self.config.get_config()["csv_split_char"],
num_total_records=self.n_total_records,
)

if self.runtime_config["sampling_only"]:
print("sample is generated and saved, end.")
return
Expand Down Expand Up @@ -259,12 +260,15 @@ def execute(self, sql):
xys["data"], self.runtime_config, network_size="large"
)

qe_mdn = MdnQueryEngine(kdeModelWrapper, config=self.config.copy())
qe_mdn = MdnQueryEngine(
kdeModelWrapper, config=self.config.copy())

qe_mdn.serialize2warehouse(
self.config.get_config()["warehousedir"], self.runtime_config
self.config.get_config()[
"warehousedir"], self.runtime_config
)
self.model_catalog.add_model_wrapper(qe_mdn, self.runtime_config)
self.model_catalog.add_model_wrapper(
qe_mdn, self.runtime_config)

else: # if group by is involved in the query
if self.config.get_config()["reg_type"] == "qreg":
Expand All @@ -276,7 +280,8 @@ def execute(self, sql):
headers=table_header,
)

n_sample_point = get_group_count_from_df(xys, groupby_attribute)
n_sample_point = get_group_count_from_df(
xys, groupby_attribute)
groupby_model_wrapper = GroupByModelTrainer(
mdl,
tbl,
Expand Down Expand Up @@ -427,7 +432,8 @@ def execute(self, sql):
elif method.lower() == "stratified":
# check if this query could be served by frequency table only.

b_ft_only = parse_y_check_need_ft_only(usecols)
b_ft_only = parse_y_check_need_ft_only(
usecols)
# print("b_ft_only", b_ft_only)
if b_ft_only:
# print("to implement")
Expand Down Expand Up @@ -486,7 +492,8 @@ def execute(self, sql):
if (
not xheader_categorical
): # For WHERE clause without categorical equality
n_total_point.pop("if_contain_x_categorical")
n_total_point.pop(
"if_contain_x_categorical")
qe_mdn = MdnQueryEngineNoRange(
config=self.config.copy()
)
Expand Down Expand Up @@ -521,14 +528,16 @@ def execute(self, sql):
if method.lower() == "uniform":
if not n_total_point["if_contain_x_categorical"]:
if not self.config.get_config()["b_use_gg"]:
n_total_point.pop("if_contain_x_categorical")
n_total_point.pop(
"if_contain_x_categorical")
kdeModelWrapper = KdeModelTrainer(
mdl,
tbl,
xheader_continous[0],
yheader,
groupby_attribute=groupby_attribute,
groupby_values=list(n_total_point.keys()),
groupby_values=list(
n_total_point.keys()),
n_total_point=n_total_point,
x_min_value=-np.inf,
x_max_value=np.inf,
Expand All @@ -543,7 +552,8 @@ def execute(self, sql):
kdeModelWrapper, config=self.config.copy()
)
qe_mdn.serialize2warehouse(
self.config.get_config()["warehousedir"],
self.config.get_config()[
"warehousedir"],
self.runtime_config,
)
self.model_catalog.add_model_wrapper(
Expand All @@ -568,7 +578,8 @@ def execute(self, sql):
queryEngineBundle, self.runtime_config
)
queryEngineBundle.serialize2warehouse(
self.config.get_config()["warehousedir"],
self.config.get_config()[
"warehousedir"],
self.runtime_config,
)
else: # x has categorical attributes
Expand Down Expand Up @@ -640,15 +651,17 @@ def execute(self, sql):
runtime_config=self.runtime_config,
)
qe.serialize2warehouse(
self.config.get_config()["warehousedir"],
self.config.get_config()[
"warehousedir"],
self.runtime_config,
)
self.model_catalog.add_model_wrapper(
qe, self.runtime_config
)

qe.serialize2warehouse(
self.config.get_config()["warehousedir"],
self.config.get_config()[
"warehousedir"],
self.runtime_config,
)
self.model_catalog.add_model_wrapper(
Expand Down Expand Up @@ -685,7 +698,8 @@ def execute(self, sql):
runtime_config=self.runtime_config,
)
qe.serialize2warehouse(
self.config.get_config()["warehousedir"],
self.config.get_config()[
"warehousedir"],
self.runtime_config,
)
self.model_catalog.add_model_wrapper(
Expand Down Expand Up @@ -721,7 +735,8 @@ def execute(self, sql):
runtime_config=self.runtime_config,
)
qe.serialize2warehouse(
self.config.get_config()["warehousedir"],
self.config.get_config()[
"warehousedir"],
self.runtime_config,
)
self.model_catalog.add_model_wrapper(
Expand Down Expand Up @@ -762,7 +777,7 @@ def execute(self, sql):
)

if (
mdl + self.runtime_config["model_suffix"]
mdl + self.runtime_config["model_suffix"]
not in self.model_catalog.model_catalog
):
print("Model " + mdl + " does not exist.")
Expand Down Expand Up @@ -829,10 +844,11 @@ def execute(self, sql):
print(predictions.to_string(index=False)) # max_rows=5

if self.runtime_config["result2file"]:
predictions.to_csv(self.runtime_config["result2file"],header=False, sep=',', index=False, quoting=csv.QUOTE_NONE, quotechar="", escapechar=" ")
predictions.to_csv(self.runtime_config["result2file"], header=False, sep=',',
index=False, quoting=csv.QUOTE_NONE, quotechar="", escapechar=" ")
# print(predictions.to_csv(sep=',', index=False)) # sep='\t'
# with open(self.runtime_config["result2file"],'w') as f:
# out =
# out =
# f.write(predictions.to_string(index=False)) # max_rows=5

if self.runtime_config["b_show_latency"]:
Expand Down Expand Up @@ -883,7 +899,8 @@ def execute(self, sql):
print("device is set to " + value)
else:
if value == "gpu":
print("GPU is not available, use CPU instead")
print(
"GPU is not available, use CPU instead")
value = "cpu"
if value == "cpu":
if self.runtime_config["v"]:
Expand Down Expand Up @@ -923,7 +940,8 @@ def execute(self, sql):
t_start = datetime.now()
if self.runtime_config["b_print_to_screen"]:
for key in self.model_catalog.model_catalog:
print(key.replace(self.runtime_config["model_suffix"], ""))
print(key.replace(
self.runtime_config["model_suffix"], ""))
if self.runtime_config["v"]:
t_end = datetime.now()
time_cost = (t_end - t_start).total_seconds()
Expand Down
22 changes: 20 additions & 2 deletions tests/integration/dbestclient/executor/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,9 +526,26 @@ def test_memory(self):
self.assertTrue(h.heap())


class BeijingPM25(unittest.TestCase):
    """Integration test that trains and queries a model built from pm25.csv."""

    def test_simple_model(self):
        """Build the ``pm25`` model, run a ranged AVG query, and check the result.

        The executor is configured with a handful of ``set`` statements, a
        model is created from ``pm25.csv`` using uniform sampling of 1000
        points, and ``avg(pm25)`` is queried for ``1000 <= PRES <= 1020``.
        The test passes when the returned object's ``empty`` attribute is
        falsy.
        """
        executor = SqlExecutor()

        # Session settings, issued one by one in a fixed order before training.
        for setting in (
            "set n_epoch=10",
            "set reg_type='mdn'",
            "set density_type='mdn'",
            "set b_grid_search='False'",
            "set csv_split_char=','",
        ):
            executor.execute(setting)

        executor.execute(
            "create table pm25(pm25 real, PRES real) from pm25.csv method uniform size 1000"
        )
        predictions = executor.execute(
            "select avg(pm25) from pm25 where 1000 <=PRES<= 1020 "
        )

        # Drop the model before asserting so it is removed even when the
        # assertion below fails.
        executor.execute("drop table pm25")

        self.assertFalse(predictions.empty)


if __name__ == "__main__":
unittest.main()
# TestTpcDs().test_simple_model()
# unittest.main()
TestTpcDs().test_simple_model()
# TestTpcDs().test_groupbys_range_no_categorical_gb1()
# TestTpcDs().test_groupbys_range_no_categorical_gb2()
# TestTpcDs().test_groupbys_range_no_categorical_gb1_stratified()
Expand All @@ -546,3 +563,4 @@ def test_memory(self):
# TestTpcDs().test_no_continuous_categorical2_one_model_stratified()
# TestTpcDs().test_plot()
# TestTpcDs().test_memory()
# BeijingPM25().test_simple_model()

0 comments on commit f00892d

Please sign in to comment.