diff --git a/bin/q b/bin/q index fbd58791..56c6fe0e 100755 --- a/bin/q +++ b/bin/q @@ -476,16 +476,18 @@ class TableColumnInferer(object): self.rows = [] self.skip_header = skip_header self.header_row = None + self.header_row_filename = None self.expected_column_count = expected_column_count self.input_delimiter = input_delimiter self.disable_column_type_detection = disable_column_type_detection - def analyze(self, col_vals): + def analyze(self, filename, col_vals): if self.inferred: raise Exception("Already inferred columns") if self.skip_header and self.header_row is None: self.header_row = col_vals + self.header_row_filename = filename else: self.rows.append(col_vals) @@ -905,17 +907,36 @@ class TableCreator(object): mfs = MaterializedFileState(filename,f,self.encoding,dialect,is_stdin) self.materialized_file_dict[filename] = mfs + def _should_skip_extra_headers(self, filenumber, filename, mfs, col_vals): + if not self.skip_header: + return False + + if filenumber == 0: + return False + + header_already_exists = self.column_inferer.header_row is not None + + is_extra_header = self.skip_header and mfs.lines_read == 1 and header_already_exists + + if is_extra_header: + if tuple(self.column_inferer.header_row) != tuple(col_vals): + raise BadHeaderException("Extra header {} in file {} mismatches original header {} from file {}. Table name is {}".format(",".join(col_vals),mfs.filename,",".join(self.column_inferer.header_row),self.column_inferer.header_row_filename,self.filenames_str)) + + return is_extra_header + def _populate(self,dialect,stop_after_analysis=False): total_data_lines_read = 0 # For each match - for filename in self.materialized_file_list: + for filenumber,filename in enumerate(self.materialized_file_list): mfs = self.materialized_file_dict[filename] try: try: for col_vals in mfs.read_file_using_csv(): - self._insert_row(col_vals) + if self._should_skip_extra_headers(filenumber,filename,mfs,col_vals): + continue + self._insert_row(filename, col_vals) if stop_after_analysis and self.column_inferer.inferred: return if mfs.lines_read == 0 and self.skip_header: @@ -937,7 +958,7 @@ class TableCreator(object): if not self.table_created: self.column_inferer.force_analysis() - self._do_create_table() + self._do_create_table(filename) if total_data_lines_read == 0: @@ -960,20 +981,20 @@ class TableCreator(object): self.state = TableCreatorState.FULLY_READ return - def _flush_pre_creation_rows(self): + def _flush_pre_creation_rows(self, filename): for i, col_vals in enumerate(self.pre_creation_rows): if self.skip_header and i == 0: # skip header line continue - self._insert_row(col_vals) + self._insert_row(filename, col_vals) self._flush_inserts() self.pre_creation_rows = [] - def _insert_row(self, col_vals): + def _insert_row(self, filename, col_vals): # If table has not been created yet if not self.table_created: # Try to create it along with another "example" line of data - self.try_to_create_table(col_vals) + self.try_to_create_table(filename, col_vals) # If the table is still not created, then we don't have enough data, just # store the data and return @@ -1069,19 +1090,19 @@ class TableCreator(object): # print self.db.execute_and_fetch(self.db.generate_end_transaction()) self.buffered_inserts = [] - def try_to_create_table(self, col_vals): + def try_to_create_table(self, filename, col_vals): if self.table_created: raise Exception('Table is already created') # Add that line to the column inferer - result = self.column_inferer.analyze(col_vals) + result = self.column_inferer.analyze(filename, col_vals) # If inferer succeeded, if result: - self._do_create_table() + self._do_create_table(filename) else: pass # We don't have enough information for creating the table yet - def _do_create_table(self): + def _do_create_table(self,filename): # Then generate a temp table name self.table_name = self.db.generate_temp_table_name() # Get the column definition dict from the inferer @@ -1101,7 +1122,7 @@ class TableCreator(object): self.db.execute_and_fetch(create_table_stmt) # Mark the table as created self.table_created = True - self._flush_pre_creation_rows() + self._flush_pre_creation_rows(filename) def drop_table(self): if self.table_created: diff --git a/test/test-suite b/test/test-suite index e17afcd1..bc7fc379 100755 --- a/test/test-suite +++ b/test/test-suite @@ -93,6 +93,9 @@ sample_data_with_empty_string_no_header = six.b("\n").join( sample_data_with_header = header_row + six.b("\n") + sample_data_no_header sample_data_with_missing_header_names = six.b("name,value1\n") + sample_data_no_header +def generate_sample_data_with_header(header): + return header + six.b("\n") + sample_data_no_header + sample_quoted_data = six.b('''non_quoted regular_double_quoted double_double_quoted escaped_double_quoted multiline_double_double_quoted multiline_escaped_double_quoted control-value-1 "control-value-2" control-value-3 "control-value-4" control-value-5 "control-value-6" non-quoted-value "this is a quoted value" "this is a ""double double"" quoted value" "this is an escaped \\"quoted value\\"" "this is a double double quoted ""multiline @@ -1422,6 +1425,109 @@ class BasicTests(AbstractQTestCase): self.cleanup(tmpfile) +class MultiHeaderTests(AbstractQTestCase): + def test_output_header_when_multiple_input_headers_exist(self): + TMPFILE_COUNT = 5 + tmpfiles = [self.create_file_with_data(sample_data_with_header) for x in range(TMPFILE_COUNT)] + + tmpfilenames = "+".join(map(lambda x:x.name, tmpfiles)) + + cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % tmpfilenames + retcode, o, e = run_command(cmd) + + self.assertEqual(retcode, 0) + self.assertEqual(len(o), TMPFILE_COUNT*3+1) + self.assertEqual(o[0], six.b("name,value1,value2")) + + for i in range (TMPFILE_COUNT): + self.assertEqual(o[1+i],sample_data_rows[0]) + for i in range (TMPFILE_COUNT): + self.assertEqual(o[TMPFILE_COUNT+1+i],sample_data_rows[1]) + for i in range (TMPFILE_COUNT): + self.assertEqual(o[TMPFILE_COUNT*2+1+i],sample_data_rows[2]) + + for oi in o[1:]: + self.assertTrue(six.b('name') not in oi) + + for i in range(TMPFILE_COUNT): + self.cleanup(tmpfiles[i]) + + def test_output_header_when_extra_header_column_names_are_different(self): + tmpfile1 = self.create_file_with_data(sample_data_with_header) + tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(six.b('othername,value1,value2'))) + + cmd = '../bin/q -d , "select name,value1,value2 from %s+%s order by name" -H -O' % (tmpfile1.name,tmpfile2.name) + retcode, o, e = run_command(cmd) + + self.assertEqual(retcode, 35) + self.assertEqual(len(o), 0) + self.assertEqual(len(e), 1) + self.assertTrue(e[0].startswith(six.b("Bad header row:"))) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + def test_output_header_when_extra_header_has_different_number_of_columns(self): + tmpfile1 = self.create_file_with_data(sample_data_with_header) + tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(six.b('name,value1'))) + + cmd = '../bin/q -d , "select name,value1,value2 from %s+%s order by name" -H -O' % (tmpfile1.name,tmpfile2.name) + retcode, o, e = run_command(cmd) + + self.assertEqual(retcode, 35) + self.assertEqual(len(o), 0) + self.assertEqual(len(e), 1) + self.assertTrue(e[0].startswith(six.b("Bad header row:"))) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + def test_output_header_when_extra_header_has_different_number_of_columns2(self): + original_header = header_row + tmpfile1 = self.create_file_with_data(sample_data_with_header) + different_header = six.b('name,value1,value2,value3') + tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(different_header)) + + SELECT_table_name = '%s+%s' % (tmpfile1.name,tmpfile2.name) + cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % (SELECT_table_name) + retcode, o, e = run_command(cmd) + + self.assertEqual(retcode, 35) + self.assertEqual(len(o), 0) + self.assertEqual(len(e), 1) + expected_message = six.b('Bad header row: Extra header %s in file %s mismatches original header %s from file %s. Table name is %s') % \ + (different_header,six.b(tmpfile2.name),original_header,six.b(tmpfile1.name),six.b(SELECT_table_name)) + + self.assertEqual(e[0],expected_message) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + # Not the best behavior, this means that if the first file in additional files contains exactly the + # same content as the original header, then q would skip this line instead of failing. + # Extremely rare case, and for any table with numeric values, this is not an issue, since column names + # cannot be numbers. + def test_output_header_when_additional_files_dont_have_a_header(self): + original_header = header_row + tmpfile1 = self.create_file_with_data(sample_data_with_header) + tmpfile2 = self.create_file_with_data(sample_data_no_header) + + SELECT_table_name = '%s+%s' % (tmpfile1.name,tmpfile2.name) + cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % (SELECT_table_name) + retcode, o, e = run_command(cmd) + + self.assertEqual(retcode, 35) + self.assertEqual(len(o), 0) + self.assertEqual(len(e), 1) + expected_message = six.b('Bad header row: Extra header %s in file %s mismatches original header %s from file %s. Table name is %s') % \ + (sample_data_rows[0],six.b(tmpfile2.name),original_header,six.b(tmpfile1.name),six.b(SELECT_table_name)) + + self.assertEqual(e[0],expected_message) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + class ParsingModeTests(AbstractQTestCase): def test_strict_mode_column_count_mismatch_error(self): @@ -2351,7 +2457,8 @@ def suite(): formatting = tl.loadTestsFromTestCase(FormattingTests) basic_module_stuff = tl.loadTestsFromTestCase(BasicModuleTests) save_db_to_disk_tests = tl.loadTestsFromTestCase(SaveDbToDiskTests) - return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting,save_db_to_disk_tests]) + multi_header_tests = tl.loadTestsFromTestCase(MultiHeaderTests) + return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting,save_db_to_disk_tests,multi_header_tests]) if __name__ == '__main__': if len(sys.argv) > 1: