diff --git a/ioutil.c b/ioutil.c index 4e6aa55..301db77 100644 --- a/ioutil.c +++ b/ioutil.c @@ -2,6 +2,8 @@ #include #include "types.h" #include "ioutil.h" +#include "vec.h" +#include #ifndef _WIN32 @@ -31,7 +33,7 @@ FH createfile(const char *path,int secret) int fd; do { fd = open(path,O_WRONLY | O_CREAT | O_TRUNC,secret ? 0600 : 0666); - if (fd == -1) { + if (fd < 0) { if (errno == EINTR) continue; return -1; @@ -45,7 +47,7 @@ int closefile(FH fd) int cret; do { cret = close(fd); - if (cret == -1) { + if (cret < 0) { if (errno == EINTR) continue; return -1; @@ -59,6 +61,107 @@ int createdir(const char *path,int secret) return mkdir(path,secret ? 0700 : 0777); } +int syncwrite(const char *filename,int secret,const u8 *data,size_t datalen) +{ + //fprintf(stderr,"filename = %s\n",filename); + + VEC_STRUCT(,char) tmpname; + size_t fnlen = strlen(filename); + VEC_INIT(tmpname); + VEC_ADDN(tmpname,fnlen + 4 /* ".tmp" */ + 1 /* "\0" */); + memcpy(&VEC_BUF(tmpname,0),filename,fnlen); + strcpy(&VEC_BUF(tmpname,fnlen),".tmp"); + const char *tmpnamestr = &VEC_BUF(tmpname,0); + + //fprintf(stderr,"tmpnamestr = %s\n",tmpnamestr); + + FH f = createfile(tmpnamestr,secret); + if (f == FH_invalid) + return -1; + + if (writeall(f,data,datalen) < 0) { + closefile(f); + remove(tmpnamestr); + return -1; + } + + int sret; + do { + sret = fsync(f); + if (sret < 0) { + if (errno == EINTR) + continue; + + closefile(f); + remove(tmpnamestr); + return -1; + } + } while (0); + + if (closefile(f) < 0) { + remove(tmpnamestr); + return -1; + } + + if (rename(tmpnamestr,filename) < 0) { + remove(tmpnamestr); + return -1; + } + + VEC_STRUCT(,char) dirname; + const char *dirnamestr; + + for (ssize_t x = ((ssize_t)fnlen) - 1;x >= 0;--x) { + if (filename[x] == '/') { + if (x) + --x; + ++x; + VEC_INIT(dirname); + VEC_ADDN(dirname,x + 1); + memcpy(&VEC_BUF(dirname,0),filename,x); + VEC_BUF(dirname,x) = '\0'; + dirnamestr = &VEC_BUF(dirname,0); + goto foundslash; + } + } + /* not found slash, fall back to "." */ + dirnamestr = "."; + +foundslash: + //fprintf(stderr,"dirnamestr = %s\n",dirnamestr); + ; + + int dirf; + do { + dirf = open(dirnamestr,O_RDONLY); + if (dirf < 0) { + if (errno == EINTR) + continue; + + // failed for non-eintr reasons + goto skipdsync; // don't really care enough + } + } while (0); + + do { + sret = fsync(dirf); + if (sret < 0) { + if (errno == EINTR) + continue; + + // failed for non-eintr reasons + break; // don't care + } + } while (0); + + (void) closefile(dirf); // don't care + +skipdsync: + + return 0; +} + + #else int writeall(FH fd,const u8 *data,size_t len) @@ -99,6 +202,49 @@ int createdir(const char *path,int secret) return CreateDirectoryA(path,0) ? 0 : -1; } + + +int syncwrite(const char *filename,int secret,const char *data,size_t datalen) +{ + VEC_STRUCT(,char) tmpname; + size_t fnlen = strlen(filename); + VEC_INIT(tmpname); + VEC_ADDN(tmpname,fnlen + 4 /* ".tmp" */ + 1 /* "\0" */); + memcpy(&VEC_BUF(tmpname,0),filename,fnlen); + strcpy(&VEC_BUF(tmpname,fnlen),".tmp"); + const char *tmpnamestr = &VEC_BUF(tmpname,0); + + FH f = createfile(tmpnamestr,secret) + if (f == FH_invalid) + return -1; + + if (writeall(f,data,datalen) < 0) { + closefile(f); + remove(tmpnamestr); + return -1; + } + + if (FlushFileBuffers(f) == 0) { + closefile(f); + remove(tmpnamestr); + return -1; + } + + if (closefile(f) < 0) { + remove(tmpnamestr); + return -1; + } + + if (MoveFileA(tmpnamestr,filename) == 0) { + remove(tmpnamestr); + return -1; + } + + // can't fsync parent dir on windows so just end here + + return 0; +} + #endif int writetofile(const char *path,const u8 *data,size_t len,int secret) diff --git a/ioutil.h b/ioutil.h index c7a1dab..5244508 100644 --- a/ioutil.h +++ b/ioutil.h @@ -18,3 +18,4 @@ int closefile(FH fd); int writeall(FH,const u8 *data,size_t len); int writetofile(const char *path,const u8 *data,size_t len,int secret); int createdir(const char *path,int secret); +int syncwrite(const char *filename,int secret,const u8 *data,size_t datalen); diff --git a/main.c b/main.c index 85e637c..30c630e 100644 --- a/main.c +++ b/main.c @@ -29,6 +29,8 @@ #include "worker.h" +#include "likely.h" + #ifndef _WIN32 #define FSZ "%zu" #else @@ -182,35 +184,55 @@ static void setpassphrase(const char *pass) static void savecheckpoint(void) { - if (checkpointfile) { - // Open checkpoint file - FILE *checkout = fopen(checkpointfile, "w"); - if (!checkout) { - fprintf(stderr,"cannot open checkpoint file for writing\n"); - exit(1); - } - - // Calculate checkpoint as the difference between original seed and the current seed - u8 checkpoint[SEED_LEN]; - bool carry = 0; - pthread_mutex_lock(&determseed_mutex); - for (int i = 0; i < SEED_LEN; i++) { - checkpoint[i] = determseed[i] - orig_determseed[i] - carry; - carry = checkpoint[i] > determseed[i]; - } - pthread_mutex_unlock(&determseed_mutex); - - // Write checkpoint file - if(fwrite(checkpoint, 1, SEED_LEN, checkout) != SEED_LEN) { - fprintf(stderr,"cannot write to checkpoint file\n"); - exit(1); - } - fclose(checkout); + u8 checkpoint[SEED_LEN]; + bool carry = 0; + pthread_mutex_lock(&determseed_mutex); + for (int i = 0; i < SEED_LEN; i++) { + checkpoint[i] = determseed[i] - orig_determseed[i] - carry; + carry = checkpoint[i] > determseed[i]; } + pthread_mutex_unlock(&determseed_mutex); + + if (syncwrite(checkpointfile,1,checkpoint,SEED_LEN) < 0) { + pthread_mutex_lock(&fout_mutex); + fprintf(stderr,"ERROR: could not save checkpoint\n"); + pthread_mutex_unlock(&fout_mutex); + } +} + +static volatile int checkpointer_endwork = 0; + +static void *checkpointworker(void *arg) +{ + (void) arg; + + struct timespec ts; + memset(&ts,0,sizeof(ts)); + ts.tv_nsec = 100000000; + + struct timespec nowtime; + u64 ilasttime,inowtime; + clock_gettime(CLOCK_MONOTONIC,&nowtime); + ilasttime = (1000000 * (u64)nowtime.tv_sec) + ((u64)nowtime.tv_nsec / 1000); + + while (!unlikely(checkpointer_endwork)) { + + clock_gettime(CLOCK_MONOTONIC,&nowtime); + inowtime = (1000000 * (u64)nowtime.tv_sec) + ((u64)nowtime.tv_nsec / 1000); + + if (inowtime - ilasttime >= 300 * 1000000 /* 5 minutes */) { + savecheckpoint(); + ilasttime = inowtime; + } + } + + savecheckpoint(); + + return 0; } #endif -VEC_STRUCT(threadvec, pthread_t); +VEC_STRUCT(threadvec,pthread_t); #include "filters_inc.inc.h" #include "filters_main.inc.h" @@ -458,6 +480,11 @@ int main(int argc,char **argv) exit(1); } + if (checkpointfile && !deterministic) { + fprintf(stderr,"--checkpoint requires passphrase\n"); + exit(1); + } + if (outfile) { fout = fopen(outfile,!outfileoverwrite ? "a" : "w"); if (!fout) { @@ -543,16 +570,16 @@ int main(int argc,char **argv) numthreads,numthreads == 1 ? "thread" : "threads"); #ifdef PASSPHRASE - memcpy(orig_determseed, determseed, sizeof(determseed)); if (deterministic) { + memcpy(orig_determseed,determseed,sizeof(determseed)); if (!quietflag && numneedgenerate != 1) fprintf(stderr,"CAUTION: avoid using keys generated with same password for unrelated services, as single leaked key may help attacker to regenerate related keys.\n"); if (checkpointfile) { // Read current checkpoint position if file exists - FILE *checkout = fopen(checkpointfile, "r"); + FILE *checkout = fopen(checkpointfile,"r"); if (checkout) { u8 checkpoint[SEED_LEN]; - if(fread(checkpoint, 1, SEED_LEN, checkout) != SEED_LEN) { + if(fread(checkpoint,1,SEED_LEN,checkout) != SEED_LEN) { fprintf(stderr,"failed to read checkpoint file\n"); exit(1); } @@ -634,6 +661,18 @@ int main(int argc,char **argv) perror("pthread_attr_destroy"); } +#if PASSPHRASE + pthread_t checkpoint_thread; + + if (checkpointfile) { + tret = pthread_create(&checkpoint_thread,NULL,checkpointworker,NULL); + if (tret) { + fprintf(stderr,"error while making checkpoint thread: %s\n",strerror(tret)); + exit(1); + } + } +#endif + #ifdef STATISTICS struct timespec nowtime; u64 istarttime,inowtime,ireporttime = 0,elapsedoffset = 0; @@ -643,20 +682,15 @@ int main(int argc,char **argv) } istarttime = (1000000 * (u64)nowtime.tv_sec) + ((u64)nowtime.tv_nsec / 1000); #endif + struct timespec ts; memset(&ts,0,sizeof(ts)); ts.tv_nsec = 100000000; - u16 loopcounter = 0; while (!endwork) { if (numneedgenerate && keysgenerated >= numneedgenerate) { endwork = 1; break; } - loopcounter++; - if (loopcounter >= 3000) { // Save checkpoint every 5 minutes - savecheckpoint(); - loopcounter = 0; - } nanosleep(&ts,0); #ifdef STATISTICS @@ -722,14 +756,18 @@ int main(int argc,char **argv) #endif } -#ifdef PASSPHRASE - savecheckpoint(); -#endif - if (!quietflag) fprintf(stderr,"waiting for threads to finish..."); + for (size_t i = 0;i < VEC_LENGTH(threads);++i) pthread_join(VEC_BUF(threads,i),0); +#ifdef PASSPHRASE + if (checkpointfile) { + checkpointer_endwork = 1; + pthread_join(checkpoint_thread,0); + } +#endif + if (!quietflag) fprintf(stderr," done.\n");