From d6d6e11fe399495fc1e7834ebd3e0cec94a7e044 Mon Sep 17 00:00:00 2001
From: Assaf Gordon
Date: Wed, 6 Mar 2013 18:25:49 -0500
Subject: [PATCH 2/2] shuf: use reservoir-sampling when possible

Reservoir Sampling enables selecting K random lines from a very large
(or unknown-sized) input:
http://en.wikipedia.org/wiki/Reservoir_sampling

* src/shuf.c: Use reservoir-sampling when the number of output lines
is known (by using '-n X' parameter).
read_input_reservoir_sampling() - read lines from input file, and keep
only K lines in memory, replacing lines with decreasing probability.
Buffers of evicted lines are reused, and the reservoir grows on demand
instead of being allocated for K lines up front.
write_permuted_output_reservoir() - output permuted reservoir lines.
devmsg(), devmsg_line() - debugging output, enabled by the hidden
'---debug' option.
main() - if the number of lines is known, use reservoir-sampling
instead of reading entire input file, and permute only the lines
that were actually sampled.
---
 src/shuf.c |  186 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 files changed, 180 insertions(+), 6 deletions(-)

diff --git a/src/shuf.c b/src/shuf.c
index 71ac3e6..7c5d16c 100644
--- a/src/shuf.c
+++ b/src/shuf.c
@@ -25,6 +25,7 @@
 #include "error.h"
 #include "fadvise.h"
 #include "getopt.h"
+#include "linebuffer.h"
 #include "quote.h"
 #include "quotearg.h"
 #include "randint.h"
@@ -81,7 +82,8 @@ With no FILE, or when FILE is -, read standard input.\n\
    non-character as a pseudo short option, starting with CHAR_MAX + 1.  */
 enum
 {
-  RANDOM_SOURCE_OPTION = CHAR_MAX + 1
+  RANDOM_SOURCE_OPTION = CHAR_MAX + 1,
+  DEV_DEBUG_OPTION
 };
 
 static struct option const long_opts[] =
@@ -92,11 +94,31 @@ static struct option const long_opts[] =
   {"output", required_argument, NULL, 'o'},
   {"random-source", required_argument, NULL, RANDOM_SOURCE_OPTION},
   {"zero-terminated", no_argument, NULL, 'z'},
+  {"-debug", no_argument, NULL, DEV_DEBUG_OPTION},
   {GETOPT_HELP_OPTION_DECL},
   {GETOPT_VERSION_OPTION_DECL},
   {0, 0, 0, 0},
 };
 
+/* Debugging output for developers; enables devmsg().  */
+static bool dev_debug = false;
+
+/* Like error(0, 0, ...), but without an implicit newline.
+   Also a no-op unless the global dev_debug is set.
+   TODO: Replace with variadic macro in system.h or
+   move to a separate module.  */
+static inline void
+devmsg (char const *fmt, ...)
+{
+  if (dev_debug)
+    {
+      va_list ap;
+      va_start (ap, fmt);
+      vfprintf (stderr, fmt, ap);
+      va_end (ap);
+    }
+}
+
 static bool
 input_numbers_option_used (size_t lo_input, size_t hi_input)
 {
@@ -135,6 +157,126 @@ next_line (char *line, char eolbyte, size_t n)
   return p + 1;
 }
 
+/* In developer-debug mode, write LINE to stderr without its trailing
+   delimiter (it is assumed to end in one, as readlinebuffer_delim
+   leaves it), followed by a closing quote and a newline.  */
+static void
+devmsg_line (struct linebuffer const *line)
+{
+  if (dev_debug)
+    {
+      fwrite (line->buffer, sizeof (char), line->length - 1, stderr);
+      fputs ("'\n", stderr);
+    }
+}
+
+/* Read lines delimited by EOLBYTE from IN, keeping a uniformly random
+   sample of at most K of them ("reservoir sampling", Algorithm R).
+   Use the random source S to choose which lines to keep.
+   Store the sampled lines, each including its trailing EOLBYTE, into a
+   newly allocated array whose address is stored in *OUT_RSRV, and
+   return the number of sampled lines, i.e., MIN (K, number of input
+   lines).  The lines are in no particular order; the caller permutes
+   them.  Exit with a diagnostic on a read error or line-count
+   overflow.  */
+static size_t
+read_input_reservoir_sampling (FILE *in, char eolbyte, size_t k,
+                               struct randint_source *s,
+                               struct linebuffer **out_rsrv)
+{
+  /* Count lines with the randint type (uintmax_t), so that long inputs
+     do not wrap the count even where size_t is only 32 bits wide.  */
+  randint n_lines = 0;
+  size_t n_alloc = 0;
+  struct linebuffer *rsrv = NULL;
+
+  devmsg ("--reservoir_sampling--\n");
+
+  /* Store the first K lines directly into the reservoir.  Grow it on
+     demand, so that a large K with a short input does not allocate
+     K line buffers up front.  */
+  while (n_lines < k)
+    {
+      if (n_lines == n_alloc)
+        {
+          size_t old_alloc = n_alloc;
+          rsrv = x2nrealloc (rsrv, &n_alloc, sizeof *rsrv);
+          /* An all-zero linebuffer is empty, just as after initbuffer.  */
+          memset (rsrv + old_alloc, 0, (n_alloc - old_alloc) * sizeof *rsrv);
+        }
+
+      if (! readlinebuffer_delim (&rsrv[n_lines], in, eolbyte))
+        break;
+
+      devmsg ("filling reservoir, input line %zu of %zu: '",
+              (size_t) n_lines + 1, k);
+      devmsg_line (&rsrv[n_lines]);
+      n_lines++;
+    }
+
+  /* Once the reservoir is full, line number N_LINES (counting from 0)
+     replaces a random reservoir entry with probability
+     K / (N_LINES + 1).  Read each line into SCRATCH and, if it is kept,
+     swap it with the evicted entry, so that the evicted buffer is
+     reused for later lines instead of being leaked.  */
+  if (n_lines == k)
+    {
+      struct linebuffer scratch;
+      initbuffer (&scratch);
+
+      while (readlinebuffer_delim (&scratch, in, eolbyte))
+        {
+          if (n_lines == UINTMAX_MAX)
+            error (EXIT_FAILURE, EOVERFLOW, _("too many input lines"));
+
+          /* With K == 0 no line can be kept, so draw no random numbers.  */
+          if (k != 0)
+            {
+              randint j = randint_choose (s, n_lines + 1);
+              if (j < k)
+                {
+                  struct linebuffer evicted = rsrv[j];
+                  devmsg ("Replacing reservoir sample %zu with line %ju '",
+                          (size_t) j, n_lines);
+                  devmsg_line (&scratch);
+                  rsrv[j] = scratch;
+                  scratch = evicted;
+                }
+            }
+          n_lines++;
+        }
+
+      freebuffer (&scratch);
+    }
+
+  /* No more input lines, or an input error.  */
+  if (ferror (in))
+    error (EXIT_FAILURE, errno, _("read error"));
+
+  *out_rsrv = rsrv;
+  return MIN (k, n_lines);
+}
+
+/* Write to standard output the N_LINES entries of LINES, in the order
+   given by PERMUTATION, an array of N_LINES indexes into LINES.  Each
+   entry already ends in its line delimiter.  Return 0 on success, or
+   -1 on a write error.  */
+static int
+write_permuted_output_reservoir (size_t n_lines, struct linebuffer const *lines,
+                                 size_t const *permutation)
+{
+  size_t i;
+
+  for (i = 0; i < n_lines; i++)
+    {
+      struct linebuffer const *p = &lines[permutation[i]];
+      if (fwrite (p->buffer, sizeof (char), p->length, stdout) != p->length)
+        return -1;
+    }
+
+  return 0;
+}
+
 /* Read data from file IN.  Input lines are delimited by EOLBYTE;
    silently append a trailing EOLBYTE if the file ends in some other
    byte.  Store a pointer to the resulting array of lines into *PLINE.
@@ -209,14 +351,17 @@ main (int argc, char **argv)
   char *random_source = NULL;
   char eolbyte = '\n';
   char **input_lines = NULL;
+  bool use_reservoir_sampling = false;
 
   int optc;
   int n_operands;
   char **operand;
   size_t n_lines;
-  char **line;
+  char **line = NULL;
+  struct linebuffer *reservoir = NULL;
   struct randint_source *randint_source;
   size_t *permutation;
+  int i;
 
   initialize_main (&argc, &argv);
   set_program_name (argv[0]);
@@ -295,6 +440,10 @@ main (int argc, char **argv)
         eolbyte = '\0';
         break;
 
+      case DEV_DEBUG_OPTION:
+        dev_debug = true;
+        break;
+
       case_GETOPT_HELP_CHAR;
       case_GETOPT_VERSION_CHAR (PROGRAM_NAME, AUTHORS);
       default:
@@ -341,8 +490,16 @@ main (int argc, char **argv)
 
       fadvise (stdin, FADVISE_SEQUENTIAL);
 
-      n_lines = read_input (stdin, eolbyte, &input_lines);
-      line = input_lines;
+      if (head_lines != SIZE_MAX)
+        {
+          use_reservoir_sampling = true;
+          n_lines = SIZE_MAX;   /* unknown number of input lines, for now */
+        }
+      else
+        {
+          n_lines = read_input (stdin, eolbyte, &input_lines);
+          line = input_lines;
+        }
     }
 
   head_lines = MIN (head_lines, n_lines);
@@ -352,6 +509,19 @@ main (int argc, char **argv)
   if (! randint_source)
     error (EXIT_FAILURE, errno, "%s", quotearg_colon (random_source));
 
+  if (use_reservoir_sampling)
+    {
+      /* Instead of reading the entire file into 'line',
+         use reservoir-sampling to store just "head_lines" random lines.  */
+      n_lines = read_input_reservoir_sampling (stdin, eolbyte, head_lines,
+                                               randint_source, &reservoir);
+
+      /* The input may have had fewer than HEAD_LINES lines; permute and
+         output only the lines actually sampled.  Otherwise randperm_new
+         below would be asked for more indexes than N_LINES.  */
+      head_lines = n_lines;
+    }
+
   /* Close stdin now, rather than earlier, so that randint_all_new
      doesn't have to worry about opening something other than
      stdin.  */
@@ -363,8 +533,12 @@ main (int argc, char **argv)
   if (outfile && ! freopen (outfile, "w", stdout))
     error (EXIT_FAILURE, errno, "%s", quotearg_colon (outfile));
 
-  if (write_permuted_output (head_lines, line, lo_input, permutation, eolbyte)
-      != 0)
+  if (use_reservoir_sampling)
+    i = write_permuted_output_reservoir (n_lines, reservoir, permutation);
+  else
+    i = write_permuted_output (head_lines, line, lo_input,
+                               permutation, eolbyte);
+  if (i != 0)
     error (EXIT_FAILURE, errno, _("write error"));
 
 #ifdef lint
-- 
1.7.7.4