Fix output header issue for multi-file tables #212
harelba committed Dec 21, 2019
1 parent 5c96ad2 commit a603ab6
Showing 2 changed files with 142 additions and 14 deletions.
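
In brief, judging by the diff and the new tests: when a query reads a multi-file table (e.g. data1.csv+data2.csv) with -H, every file starts with its own header line, but only the first file's header was being consumed, so the headers of the additional files could leak into the result set and spoil the -O output header. The change below records which file the original header came from, silently skips the first line of each additional file when it equals that header, and raises BadHeaderException when it does not. A minimal standalone sketch of that check (illustrative only; the function name and signature here are not q's actual API, the real logic is TableCreator._should_skip_extra_headers in bin/q below):

# Illustrative sketch, not q's actual API; compare with
# TableCreator._should_skip_extra_headers in the diff below.
def should_skip_extra_header(filenumber, lines_read, known_header, col_vals):
    """Return True when col_vals is the duplicate header line of an additional file."""
    if filenumber == 0 or known_header is None:
        # First file, or no header captured yet: nothing to skip.
        return False
    if lines_read != 1:
        # Only the first line of an additional file can be its header.
        return False
    if tuple(known_header) != tuple(col_vals):
        # An additional file declares a different header -- fail loudly.
        raise ValueError("Extra header %s mismatches original header %s" %
                         (",".join(col_vals), ",".join(known_header)))
    return True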
bin/q: 47 changes (34 additions & 13 deletions)
@@ -476,16 +476,18 @@ class TableColumnInferer(object):
         self.rows = []
         self.skip_header = skip_header
         self.header_row = None
+        self.header_row_filename = None
         self.expected_column_count = expected_column_count
         self.input_delimiter = input_delimiter
         self.disable_column_type_detection = disable_column_type_detection

-    def analyze(self, col_vals):
+    def analyze(self, filename, col_vals):
         if self.inferred:
             raise Exception("Already inferred columns")

         if self.skip_header and self.header_row is None:
             self.header_row = col_vals
+            self.header_row_filename = filename
         else:
             self.rows.append(col_vals)

@@ -905,17 +907,36 @@ class TableCreator(object):
             mfs = MaterializedFileState(filename,f,self.encoding,dialect,is_stdin)
             self.materialized_file_dict[filename] = mfs

+    def _should_skip_extra_headers(self, filenumber, filename, mfs, col_vals):
+        if not self.skip_header:
+            return False
+
+        if filenumber == 0:
+            return False
+
+        header_already_exists = self.column_inferer.header_row is not None
+
+        is_extra_header = self.skip_header and mfs.lines_read == 1 and header_already_exists
+
+        if is_extra_header:
+            if tuple(self.column_inferer.header_row) != tuple(col_vals):
+                raise BadHeaderException("Extra header {} in file {} mismatches original header {} from file {}. Table name is {}".format(",".join(col_vals),mfs.filename,",".join(self.column_inferer.header_row),self.column_inferer.header_row_filename,self.filenames_str))
+
+        return is_extra_header
+
     def _populate(self,dialect,stop_after_analysis=False):
         total_data_lines_read = 0

         # For each match
-        for filename in self.materialized_file_list:
+        for filenumber,filename in enumerate(self.materialized_file_list):
             mfs = self.materialized_file_dict[filename]

             try:
                 try:
                     for col_vals in mfs.read_file_using_csv():
-                        self._insert_row(col_vals)
+                        if self._should_skip_extra_headers(filenumber,filename,mfs,col_vals):
+                            continue
+                        self._insert_row(filename, col_vals)
                         if stop_after_analysis and self.column_inferer.inferred:
                             return
                     if mfs.lines_read == 0 and self.skip_header:
@@ -937,7 +958,7 @@ class TableCreator(object):

             if not self.table_created:
                 self.column_inferer.force_analysis()
-                self._do_create_table()
+                self._do_create_table(filename)


         if total_data_lines_read == 0:
@@ -960,20 +981,20 @@ class TableCreator(object):
         self.state = TableCreatorState.FULLY_READ
         return

-    def _flush_pre_creation_rows(self):
+    def _flush_pre_creation_rows(self, filename):
         for i, col_vals in enumerate(self.pre_creation_rows):
             if self.skip_header and i == 0:
                 # skip header line
                 continue
-            self._insert_row(col_vals)
+            self._insert_row(filename, col_vals)
         self._flush_inserts()
         self.pre_creation_rows = []

-    def _insert_row(self, col_vals):
+    def _insert_row(self, filename, col_vals):
         # If table has not been created yet
         if not self.table_created:
             # Try to create it along with another "example" line of data
-            self.try_to_create_table(col_vals)
+            self.try_to_create_table(filename, col_vals)

             # If the table is still not created, then we don't have enough data, just
             # store the data and return
Expand Down Expand Up @@ -1069,19 +1090,19 @@ class TableCreator(object):
# print self.db.execute_and_fetch(self.db.generate_end_transaction())
self.buffered_inserts = []

def try_to_create_table(self, col_vals):
def try_to_create_table(self, filename, col_vals):
if self.table_created:
raise Exception('Table is already created')

# Add that line to the column inferer
result = self.column_inferer.analyze(col_vals)
result = self.column_inferer.analyze(filename, col_vals)
# If inferer succeeded,
if result:
self._do_create_table()
self._do_create_table(filename)
else:
pass # We don't have enough information for creating the table yet

def _do_create_table(self):
def _do_create_table(self,filename):
# Then generate a temp table name
self.table_name = self.db.generate_temp_table_name()
# Get the column definition dict from the inferer
@@ -1101,7 +1122,7 @@ class TableCreator(object):
         self.db.execute_and_fetch(create_table_stmt)
         # Mark the table as created
         self.table_created = True
-        self._flush_pre_creation_rows()
+        self._flush_pre_creation_rows(filename)

     def drop_table(self):
         if self.table_created:
test/test-suite: 109 changes (108 additions & 1 deletion)
@@ -93,6 +93,9 @@ sample_data_with_empty_string_no_header = six.b("\n").join(
 sample_data_with_header = header_row + six.b("\n") + sample_data_no_header
 sample_data_with_missing_header_names = six.b("name,value1\n") + sample_data_no_header

+def generate_sample_data_with_header(header):
+    return header + six.b("\n") + sample_data_no_header
+
 sample_quoted_data = six.b('''non_quoted regular_double_quoted double_double_quoted escaped_double_quoted multiline_double_double_quoted multiline_escaped_double_quoted
 control-value-1 "control-value-2" control-value-3 "control-value-4" control-value-5 "control-value-6"
 non-quoted-value "this is a quoted value" "this is a ""double double"" quoted value" "this is an escaped \\"quoted value\\"" "this is a double double quoted ""multiline
@@ -1422,6 +1425,109 @@ class BasicTests(AbstractQTestCase):
         self.cleanup(tmpfile)


+class MultiHeaderTests(AbstractQTestCase):
+    def test_output_header_when_multiple_input_headers_exist(self):
+        TMPFILE_COUNT = 5
+        tmpfiles = [self.create_file_with_data(sample_data_with_header) for x in range(TMPFILE_COUNT)]
+
+        tmpfilenames = "+".join(map(lambda x:x.name, tmpfiles))
+
+        cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % tmpfilenames
+        retcode, o, e = run_command(cmd)
+
+        self.assertEqual(retcode, 0)
+        self.assertEqual(len(o), TMPFILE_COUNT*3+1)
+        self.assertEqual(o[0], six.b("name,value1,value2"))
+
+        for i in range (TMPFILE_COUNT):
+            self.assertEqual(o[1+i],sample_data_rows[0])
+        for i in range (TMPFILE_COUNT):
+            self.assertEqual(o[TMPFILE_COUNT+1+i],sample_data_rows[1])
+        for i in range (TMPFILE_COUNT):
+            self.assertEqual(o[TMPFILE_COUNT*2+1+i],sample_data_rows[2])
+
+        for oi in o[1:]:
+            self.assertTrue(six.b('name') not in oi)
+
+        for i in range(TMPFILE_COUNT):
+            self.cleanup(tmpfiles[i])
+
+    def test_output_header_when_extra_header_column_names_are_different(self):
+        tmpfile1 = self.create_file_with_data(sample_data_with_header)
+        tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(six.b('othername,value1,value2')))
+
+        cmd = '../bin/q -d , "select name,value1,value2 from %s+%s order by name" -H -O' % (tmpfile1.name,tmpfile2.name)
+        retcode, o, e = run_command(cmd)
+
+        self.assertEqual(retcode, 35)
+        self.assertEqual(len(o), 0)
+        self.assertEqual(len(e), 1)
+        self.assertTrue(e[0].startswith(six.b("Bad header row:")))
+
+        self.cleanup(tmpfile1)
+        self.cleanup(tmpfile2)
+
+    def test_output_header_when_extra_header_has_different_number_of_columns(self):
+        tmpfile1 = self.create_file_with_data(sample_data_with_header)
+        tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(six.b('name,value1')))
+
+        cmd = '../bin/q -d , "select name,value1,value2 from %s+%s order by name" -H -O' % (tmpfile1.name,tmpfile2.name)
+        retcode, o, e = run_command(cmd)
+
+        self.assertEqual(retcode, 35)
+        self.assertEqual(len(o), 0)
+        self.assertEqual(len(e), 1)
+        self.assertTrue(e[0].startswith(six.b("Bad header row:")))
+
+        self.cleanup(tmpfile1)
+        self.cleanup(tmpfile2)
+
+    def test_output_header_when_extra_header_has_different_number_of_columns2(self):
+        original_header = header_row
+        tmpfile1 = self.create_file_with_data(sample_data_with_header)
+        different_header = six.b('name,value1,value2,value3')
+        tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(different_header))
+
+        SELECT_table_name = '%s+%s' % (tmpfile1.name,tmpfile2.name)
+        cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % (SELECT_table_name)
+        retcode, o, e = run_command(cmd)
+
+        self.assertEqual(retcode, 35)
+        self.assertEqual(len(o), 0)
+        self.assertEqual(len(e), 1)
+        expected_message = six.b('Bad header row: Extra header %s in file %s mismatches original header %s from file %s. Table name is %s') % \
+            (different_header,six.b(tmpfile2.name),original_header,six.b(tmpfile1.name),six.b(SELECT_table_name))
+
+        self.assertEqual(e[0],expected_message)
+
+        self.cleanup(tmpfile1)
+        self.cleanup(tmpfile2)
+
+    # Not the best behavior, this means that if the first file in additional files contains exactly the
+    # same content as the original header, then q would skip this line instead of failing.
+    # Extremely rare case, and for any table with numeric values, this is not an issue, since column names
+    # cannot be numbers.
+    def test_output_header_when_additional_files_dont_have_a_header(self):
+        original_header = header_row
+        tmpfile1 = self.create_file_with_data(sample_data_with_header)
+        tmpfile2 = self.create_file_with_data(sample_data_no_header)
+
+        SELECT_table_name = '%s+%s' % (tmpfile1.name,tmpfile2.name)
+        cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % (SELECT_table_name)
+        retcode, o, e = run_command(cmd)
+
+        self.assertEqual(retcode, 35)
+        self.assertEqual(len(o), 0)
+        self.assertEqual(len(e), 1)
+        expected_message = six.b('Bad header row: Extra header %s in file %s mismatches original header %s from file %s. Table name is %s') % \
+            (sample_data_rows[0],six.b(tmpfile2.name),original_header,six.b(tmpfile1.name),six.b(SELECT_table_name))
+
+        self.assertEqual(e[0],expected_message)
+
+        self.cleanup(tmpfile1)
+        self.cleanup(tmpfile2)
+
+
 class ParsingModeTests(AbstractQTestCase):

     def test_strict_mode_column_count_mismatch_error(self):
@@ -2351,7 +2457,8 @@ def suite():
     formatting = tl.loadTestsFromTestCase(FormattingTests)
     basic_module_stuff = tl.loadTestsFromTestCase(BasicModuleTests)
     save_db_to_disk_tests = tl.loadTestsFromTestCase(SaveDbToDiskTests)
-    return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting,save_db_to_disk_tests])
+    multi_header_tests = tl.loadTestsFromTestCase(MultiHeaderTests)
+    return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting,save_db_to_disk_tests,multi_header_tests])

 if __name__ == '__main__':
     if len(sys.argv) > 1:
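For reference, the behavior that test_output_header_when_multiple_input_headers_exist exercises looks roughly like the following when driven from Python. This is a hypothetical usage sketch, not part of the commit: a.csv and b.csv are made-up files assumed to share the header line name,value1,value2, and the q executable is assumed to be on the PATH.

import subprocess

# Hypothetical files a.csv and b.csv, each starting with the header "name,value1,value2".
cmd = 'q -d , "select name,value1,value2 from a.csv+b.csv order by name" -H -O'
p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()

# With this fix the output contains exactly one header line followed by the data
# rows of both files; if b.csv declared a different header, q would exit with
# code 35 and print a "Bad header row: ..." message to stderr instead.
print(out.decode('utf-8'))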
