diff --git a/src/mains/psi_demo.cpp b/src/mains/psi_demo.cpp index b86db9e..6b44a4d 100644 --- a/src/mains/psi_demo.cpp +++ b/src/mains/psi_demo.cpp @@ -16,7 +16,7 @@ int32_t psi_demonstrator(int32_t argc, char** argv) { double epsilon=1.2; uint64_t bytes_sent=0, bytes_received=0, mbfac; uint32_t nelements=0, elebytelen=16, symsecbits=128, intersect_size = 0, i, j, ntasks=1, - pnelements, *elebytelens, *res_bytelens; + pnelements, *elebytelens, *res_bytelens, nclients = 2; uint16_t port=7766; uint8_t **elements, **intersection; bool detailed_timings=false; @@ -32,7 +32,12 @@ int32_t psi_demonstrator(int32_t argc, char** argv) { read_psi_demo_options(&argc, &argv, &role, &protocol, &filename, &address, &nelements, &detailed_timings); if(role == SERVER) { - listen(address.c_str(), port, sockfd.data(), ntasks); + if(protocol == TTP) { + sockfd.resize(nclients); + listen(address.c_str(), port, sockfd.data(), nclients); + } + else + listen(address.c_str(), port, sockfd.data(), ntasks); } else { for(i = 0; i < ntasks; i++) connect(address.c_str(), port, sockfd[i]); @@ -41,13 +46,13 @@ int32_t psi_demonstrator(int32_t argc, char** argv) { gettimeofday(&t_start, NULL); //read in files and get elements and byte-length from there - read_elements(&elements, &elebytelens, &nelements, filename); if(detailed_timings) { gettimeofday(&t_end, NULL); } - pnelements = exchange_information(nelements, elebytelen, symsecbits, ntasks, protocol, sockfd[0]); + if(protocol != TTP) + pnelements = exchange_information(nelements, elebytelen, symsecbits, ntasks, protocol, sockfd[0]); //cout << "Performing private set-intersection between " << nelements << " and " << pnelements << " element sets" << endl; if(detailed_timings) { @@ -62,7 +67,9 @@ int32_t psi_demonstrator(int32_t argc, char** argv) { &crypto, sockfd.data(), ntasks); break; case TTP: - ///ttppsi(role, nelements, elebytelen, elements, &intersection, &crypto, sockfd.data(), nclients, cardinality); break; + intersect_size = ttppsi(role, nelements, elebytelens, elements, &intersection, &res_bytelens, + &crypto, sockfd.data(), ntasks); + break; case DH_ECC: intersect_size = dhpsi(role, nelements, pnelements, elebytelens, elements, &intersection, &res_bytelens, &crypto, sockfd.data(), ntasks); diff --git a/src/naive-hashing/naive-psi.cpp b/src/naive-hashing/naive-psi.cpp index 303f50d..8dd8041 100644 --- a/src/naive-hashing/naive-psi.cpp +++ b/src/naive-hashing/naive-psi.cpp @@ -73,8 +73,7 @@ uint32_t naivepsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx ectx.eles.nelements = neles; ectx.eles.output = hashes; ectx.eles.perm = perm; - - ectx.hctx.symcrypt = crypt_env; + ectx.sctx.symcrypt = crypt_env; run_task(ntasks, ectx, hash); diff --git a/src/pk-based/dh-psi.cpp b/src/pk-based/dh-psi.cpp index 92d34f7..fb37446 100644 --- a/src/pk-based/dh-psi.cpp +++ b/src/pk-based/dh-psi.cpp @@ -71,7 +71,7 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c ectx.eles.nelements = neles; ectx.eles.outbytelen = hash_bytes; ectx.eles.perm = perm; - ectx.hctx.symcrypt = crypt_env; + ectx.sctx.symcrypt = crypt_env; #ifdef DEBUG @@ -86,14 +86,14 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c ectx.eles.outbytelen = fe_bytes; ectx.eles.output = encrypted_eles; ectx.eles.hasvarbytelen = false; - ectx.ectx.field = field; - ectx.ectx.exponent = exponent; - ectx.ectx.sample = true; + ectx.actx.field = field; + ectx.actx.exponent = exponent; + ectx.actx.sample = true; #ifdef DEBUG cout << "Hash and encrypting my elements" << endl; #endif - run_task(ntasks, ectx, encrypt); + run_task(ntasks, ectx, asym_encrypt); peles = (uint8_t*) malloc(sizeof(uint8_t) * pneles * fe_bytes); @@ -110,13 +110,13 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c ectx.eles.fixedbytelen = fe_bytes; ectx.eles.outbytelen = fe_bytes; ectx.eles.hasvarbytelen = false; - ectx.ectx.exponent = exponent; - ectx.ectx.sample = false; + ectx.actx.exponent = exponent; + ectx.actx.sample = false; #ifdef DEBUG cout << "Encrypting partners elements" << endl; #endif - run_task(ntasks, ectx, encrypt); + run_task(ntasks, ectx, asym_encrypt); /* if only the cardinality should be computed, permute the elements randomly again. Otherwise don't permute */ if(cardinality) { @@ -136,7 +136,7 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c ectx.eles.outbytelen = hash_bytes; ectx.eles.hasvarbytelen = false; ectx.eles.perm = cardinality_perm; - ectx.hctx.symcrypt = crypt_env; + ectx.sctx.symcrypt = crypt_env; #ifdef DEBUG cout << "Hashing elements" << endl; diff --git a/src/server-aided/sapsi.cpp b/src/server-aided/sapsi.cpp index 922b7c0..c4c4241 100644 --- a/src/server-aided/sapsi.cpp +++ b/src/server-aided/sapsi.cpp @@ -6,8 +6,8 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) { CSocket* sockfds = socket;//(CSocket*) malloc(sizeof(CSocket) * nclients); uint32_t* neles = (uint32_t*) malloc(sizeof(uint32_t) * nclients); uint8_t** csets = (uint8_t**) malloc(sizeof(uint8_t*) * nclients); - uint8_t* intersect; - uint32_t temp, maskbytelen, intersectsize, minset, i; + uint32_t temp, maskbytelen, intersectsize, minset, i, j; + CBitVector* intersection = new CBitVector[nclients]; #ifndef BATCH cout << "Connections with all " << nclients << " clients established" << endl; @@ -19,16 +19,17 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) { sockfds[i].Receive(&temp, sizeof(uint32_t)); if(i == 0) { maskbytelen = temp; minset = neles[i];} if(neles[i] < minset) minset = neles[i]; - assert(maskbytelen == temp); #ifndef BATCH cout << "Client " << i << " holds " << neles[i] << " elements of length " << (temp * 8) << "-bit" << endl; #endif + intersection[i].ResizeinBytes(ceil_divide(neles[i], 8)); + intersection[i].Reset(); + assert(maskbytelen == temp); + } #ifndef BATCH cout <<"Receiving the client's elements" << endl; #endif - /* Allocate sufficient size for the intersecting elements */ - intersect = (uint8_t*) malloc(sizeof(uint8_t*) * minset * maskbytelen); /* Receive the permuted and masked sets of all clients */ for(i = 0; i < nclients; i++) { @@ -40,7 +41,10 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) { cout << "Computing intersection for the clients" << endl; #endif /* Compute Intersection */ - intersectsize = compute_intersection(nclients, neles, csets, intersect, maskbytelen); + intersectsize = compute_intersection(nclients, neles, csets, intersection, maskbytelen); + + + /* Enter at which position an intersection was found */ #ifndef BATCH cout << "sending all " << intersectsize << " intersecting elements to the clients" << endl; #endif @@ -48,7 +52,7 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) { for(i = 0; i < nclients; i++) { sockfds[i].Send(&intersectsize, sizeof(uint32_t)); if(!cardinality) - sockfds[i].Send(intersect, intersectsize * maskbytelen); + sockfds[i].Send(intersection[i].GetArr(), ceil_divide(neles[i], 8)); } /* Cleanup */ @@ -60,7 +64,7 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) { * for the n-party case a BF-based approach makes more sense. */ //TODO currently only works for 128 bit masks -uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, uint8_t* intersect, uint32_t entrybytelen) { +uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, CBitVector* intersection, uint32_t entrybytelen) { // Create the GHashTable GHashTable *map = NULL, *tmpmap = NULL; GHashTableIter iter; @@ -72,9 +76,10 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset NULL // cleanup value ); - uint32_t i, j, intersectsize, ctr = 0; - uint64_t* tmpval = (uint64_t*) malloc(sizeof(uint64_t)); + uint32_t i, j, intersectsize, ctr = 0, k; + uint64_t* tmpval; uint64_t* tmpkey = (uint64_t*) malloc(sizeof(uint64_t)); + uint64_t* query; #ifndef BATCH cout << "Inserting the items into the hash table " << endl; #endif @@ -83,7 +88,10 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset #ifdef DEBUG cout << "Inserted item: " << (hex) << ((uint64_t*) csets[0])[2*i] << " "<< ((uint64_t*) csets[0])[2*i+1] << (dec) << endl; #endif - g_hash_table_insert(map,(void*) &((uint64_t*)csets[0])[2*i], &(((uint64_t*)csets[0])[2*i+1])); + tmpval = (uint64_t*) malloc(2*sizeof(uint64_t)); + tmpval[0] = (((uint64_t*)csets[0])[2*i+1]); + tmpval[1] = i; + g_hash_table_insert(map,(void*) &((uint64_t*)csets[0])[2*i], tmpval);//&(((uint64_t*)csets[0])[2*i+1])); } #ifdef DEBUG g_hash_table_foreach( map, printKeyValue, NULL ); @@ -108,11 +116,17 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset cout << "Checking for Key: " << (hex) << ((uint64_t*) csets[i])[2*j] << " "<< ((uint64_t*) csets[i])[2*j+1] << (dec) << endl; #endif if(g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)csets[i])[2*j]), - NULL, (void**) &tmpval) && (*tmpval == ((uint64_t*)csets[i])[2*j+1])) { + NULL, (void**) &query) && (*query == ((uint64_t*)csets[i])[2*j+1])) { #ifdef DEBUG cout << "Key was found" << endl; #endif - g_hash_table_insert(tmpmap,(void*) &(((uint64_t*)csets[i])[2*j]),&(((uint64_t*)csets[i])[2*j+1])); + tmpval = (uint64_t*) malloc((i+2)*sizeof(uint64_t)); + tmpval[0] = (((uint64_t*)csets[i])[2*j+1]); + for(k = 1; k < i+1; k++) { + tmpval[k] = query[k]; + } + tmpval[i+1] = j; + g_hash_table_insert(tmpmap,(void*) &(((uint64_t*)csets[i])[2*j]), tmpval);//&(((uint64_t*)csets[i])[2*j+1])); } else { #ifdef DEBUG @@ -143,8 +157,9 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset #ifdef DEBUG cout << (hex) << tmpkey[0] << " " << tmpval[0] << (dec)<< endl; #endif - ((uint64_t*) intersect)[ctr++] = tmpkey[0]; - ((uint64_t*) intersect)[ctr++] = tmpval[0]; + for(i = 0; i < nclients; i++) { + intersection[i].SetBit(tmpval[i+1], 1); + } } gettimeofday(&end, NULL); #ifdef TIMING @@ -162,69 +177,54 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset } -uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements, - uint8_t** result, crypto* crypt, CSocket* socket, bool cardinality) { - uint32_t maskbytelen = 16, intersectsize, i, j; - uint8_t* masks = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen); - uint8_t* intersect = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen); - uint32_t* perm; - uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles); - - uint32_t* tmpval = (uint32_t*) malloc(sizeof(uint32_t)); - GHashTable *map; - - //crypto crypto(symsecbits, (uint8_t*) const_seed); - -// cout << "Starting client with " << neles << " elements of " << (8*elebytelen) << "-bit length with server " -// << address << ":" << port << endl; - - CSocket* sockfd = socket; - //connect(address, port, sockfd); - sockfd->Send((uint8_t*) &neles, sizeof(uint32_t)); - sockfd->Send((uint8_t*) &maskbytelen, sizeof(uint32_t)); +uint32_t client_routine(uint32_t neles, task_ctx ectx, uint32_t* matches, + crypto* crypt_env, CSocket* socket, uint32_t ntasks, bool cardinality) { + uint32_t maskbytelen, intersectsize, i, matchctr; - perm = mask_and_permute_elements(neles, elebytelen, elements, maskbytelen, masks, crypt->get_seclvl().symbits, crypt); + uint8_t* masks; + uint32_t *perm, *invperm; + CBitVector inIntersection(neles); - sockfd->Send(masks, maskbytelen * neles); + //TODO works only fine for equally sized sets, if one set is bigger than the other, this will fail! + maskbytelen = 16;//ceil_divide(crypt_env->get_seclvl().statbits + 2*ceil_log2(neles), 8); - if(!cardinality) { - for(i = 0; i < neles; i++) { - invperm[perm[i]] = i; - } + masks = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen); + perm = (uint32_t*) malloc(sizeof(uint32_t) * neles); + invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles); - map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL); - for(i = 0; i < neles; i++) { - g_hash_table_insert(map,(void*) &((uint64_t*)masks)[2*i], &(invperm[i])); - } + /* Generate the random permutation the elements */ + crypt_env->gen_rnd_perm(perm, neles); + + socket->Send((uint8_t*) &neles, sizeof(uint32_t)); + socket->Send((uint8_t*) &maskbytelen, sizeof(uint32_t)); + + ectx.eles.outbytelen = maskbytelen, + ectx.eles.nelements = neles; + ectx.eles.output = masks; + ectx.eles.perm = perm; + ectx.sctx.symcrypt = crypt_env; + ectx.sctx.keydata = (uint8_t*) const_seed; + + run_task(ntasks, ectx, hash); + + socket->Send(masks, maskbytelen * neles); + + socket->Receive(&intersectsize, sizeof(uint32_t)); + + for(i = 0; i < neles; i++) { + invperm[perm[i]] = i; } - - sockfd->Receive(&intersectsize, sizeof(uint32_t)); if(!cardinality) { - sockfd->Receive(intersect, maskbytelen * intersectsize); + socket->Receive(inIntersection.GetArr(), ceil_divide(neles, 8)); -#ifdef DEBUG - cout << "The intersection contains " << intersectsize << " elements: " << endl; - for(i = 0; i < intersectsize; i++) { - cout << (hex) << ((uint64_t*)intersect)[2*i] << " " << ((uint64_t*)intersect)[2*i+1] << (dec) << endl; + for(i = 0, matchctr = 0; i < neles; i++) { + if(inIntersection.GetBit(i)) { + matches[matchctr] = invperm[i]; + matchctr++; + } } -#endif - - *result = (uint8_t*) malloc(elebytelen * intersectsize); - //uint8_t* tmpbuf = (uint8_t*) malloc(maskbytelen); - for(i = 0; i < intersectsize; i++) { - g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)intersect)[2*i]), NULL, (void**) &tmpval); - memcpy((*result) + i * elebytelen, elements + tmpval[0] * elebytelen, elebytelen); - //crypto.decrypt(tmpbuf, intersect+i*maskbytelen, maskbytelen); - //memcpy((*result) + i * elebytelen, tmpbuf, elebytelen); -#ifdef DEBUG - cout << ((uint32_t*) elements)[tmpval[0]] << ", "; -#endif - } -#ifdef DEBUG - cout << endl; -#endif } free(perm); @@ -233,51 +233,52 @@ uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements, return intersectsize; } - -void printKeyValue( gpointer key, gpointer value, gpointer userData ) { - uint64_t realKey = *((uint64_t*)key); - uint64_t realValue = *((uint64_t*)value); - - cout << (hex) << realKey << ": " << realValue << (dec) << endl; - return; -} - -uint32_t* mask_and_permute_elements(uint32_t neles, uint32_t elebytelen, uint8_t* - elements, uint32_t maskbytelen, uint8_t* masks, uint32_t symsecbits, crypto* crypto) { - uint32_t* perm = (uint32_t*) malloc(sizeof(uint32_t) * neles); - uint8_t* maskpermptr; - uint32_t i; - - //Get random permutation - crypto->gen_rnd_perm(perm, neles); - //crypto->seed_aes_enc(client_psk); - - //Hash and permute all elements - for(i = 0; i < neles; i++) { - //cout << "Performing encryption for " << i << "-th element " << ((uint32_t*) elements)[i] << ": "; - maskpermptr = masks + perm[i] * maskbytelen; - crypto->hash(maskpermptr, maskbytelen, elements + i * elebytelen, elebytelen); - //crypto->encrypt(maskpermptr, elements+i*elebytelen, elebytelen); - //cout <<(hex)<< ((uint64_t*) maskpermptr)[0] << ((uint64_t*) maskpermptr)[1] << (dec) << endl; -#ifdef DEBUG - cout << "Resulting hash for element " << ((uint32_t*)elements)[i] << ": " << (hex) << ((uint64_t*) maskpermptr)[0] << - " " << ((uint64_t*) maskpermptr)[1] << (dec) << endl; -#endif - } - - - //free(perm); - return perm; -} - uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements, - uint8_t** intersection, crypto* crypt, CSocket* sockets, uint32_t nclients, bool cardinality) { + uint8_t** result, crypto* crypt, CSocket* sockets, uint32_t ntasks, uint32_t nclients, bool cardinality) { if(role == 0) { //Start the server //TODO maybe rerun infinitely server_routine(nclients, sockets, cardinality); return 0; } else { //Start clients - return client_routine(neles, elebytelen, elements, intersection, crypt, sockets, cardinality); + task_ctx ectx; + ectx.eles.input1d = elements; + ectx.eles.fixedbytelen = elebytelen; + ectx.eles.hasvarbytelen = false; + + uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles); + uint32_t intersect_size = client_routine(neles, ectx, matches, crypt, sockets, ntasks, cardinality); + + create_result_from_matches_fixed_bitlen(result, elebytelen, elements, matches, intersect_size); + + free(matches); + + return intersect_size; + } +} + + +uint32_t ttppsi(role_type role, uint32_t neles, uint32_t* elebytelens, uint8_t** elements, + uint8_t*** result, uint32_t** resbytelens, crypto* crypt, CSocket* sockets, + uint32_t ntasks, uint32_t nclients, bool cardinality) { + + if(role == 0) { //Start the server + //TODO maybe rerun infinitely + server_routine(nclients, sockets, cardinality); + return 0; + } else { //Start clients + task_ctx ectx; + ectx.eles.input2d = elements; + ectx.eles.varbytelens = elebytelens; + ectx.eles.hasvarbytelen = true; + + uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles); + uint32_t intersect_size = client_routine(neles, ectx, matches, crypt, sockets, ntasks, cardinality); + + create_result_from_matches_var_bitlen(result, resbytelens, elebytelens, elements, matches, intersect_size); + + free(matches); + + return intersect_size; } } diff --git a/src/server-aided/sapsi.h b/src/server-aided/sapsi.h index 6d36343..a507399 100644 --- a/src/server-aided/sapsi.h +++ b/src/server-aided/sapsi.h @@ -1,5 +1,5 @@ /* - * shpsi.h + * sapsi.h * * Created on: Jul 1, 2014 * Author: mzohner @@ -13,12 +13,18 @@ #include "../util/socket.h" #include "../util/typedefs.h" #include "../util/connection.h" +#include "../util/helpers.h" +#include "../util/cbitvector.h" /* start both roles*/ uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements, - uint8_t** intersection, crypto* crypt, CSocket* socket, uint32_t nclients = 0, bool cardinality=false); + uint8_t** intersection, crypto* crypt, CSocket* socket, uint32_t ntasks, uint32_t nclients = 2, bool cardinality=false); + +uint32_t ttppsi(role_type role, uint32_t neles, uint32_t* elebytelens, uint8_t** elements, + uint8_t*** result, uint32_t** resbytelens, crypto* crypt, CSocket* sockets, + uint32_t ntasks, uint32_t nclients = 2, bool cardinality = false); /* * Params: @@ -30,8 +36,8 @@ uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* el * port: port that the server is listening on * return: number of intersecting elements */ -uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements, - uint8_t** intersection, crypto* crypt, CSocket* socket, bool cardinality); +uint32_t client_routine(uint32_t neles, task_ctx ectx, uint32_t* matches, crypto* crypt, + CSocket* socket, uint32_t ntasks, bool cardinality); /* * Mask and permute the elements using the pre-shared key @@ -48,7 +54,7 @@ uint32_t* mask_and_permute_elements(uint32_t neles, uint32_t elebytelen, uint8_t */ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality); -uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, uint8_t* intersect, uint32_t entrybytelen); +uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, CBitVector* intersection, uint32_t entrybytelen); void printKeyValue( gpointer key, gpointer value, gpointer userData ); diff --git a/src/util/crypto/crypto.cpp b/src/util/crypto/crypto.cpp index d3964db..4e0148e 100644 --- a/src/util/crypto/crypto.cpp +++ b/src/util/crypto/crypto.cpp @@ -120,13 +120,13 @@ void crypto::gen_rnd_uniform(uint8_t* resbuf, uint64_t mod) { void crypto::encrypt(AES_KEY_CTX* enc_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) { int32_t dummy; EVP_EncryptUpdate(enc_key, resbuf, &dummy, inbuf, ninbytes); - //EVP_EncryptFinal_ex(enc_key, resbuf, &dummy); + EVP_EncryptFinal_ex(enc_key, resbuf, &dummy); } void crypto::decrypt(AES_KEY_CTX* dec_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) { int32_t dummy; //cout << "inbuf = " << (hex) << ((uint64_t*) inbuf)[0] << ((uint64_t*) inbuf)[1] << (dec) << endl; EVP_DecryptUpdate(dec_key, resbuf, &dummy, inbuf, ninbytes); - //EVP_DecryptFinal_ex(dec_key, resbuf, &dummy); + EVP_DecryptFinal_ex(dec_key, resbuf, &dummy); //cout << "outbuf = " << (hex) << ((uint64_t*) resbuf)[0] << ((uint64_t*) resbuf)[1] << (dec) << " (" << dummy << ")" << endl; } @@ -212,6 +212,10 @@ void crypto::hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t hash_routine(resbuf, noutbytes, inbuf, ninbytes, sha_hash_buf); } +void crypto::hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint8_t* tmpbuf) { + hash_routine(resbuf, noutbytes, inbuf, ninbytes, tmpbuf); +} + //A fixed-key hashing scheme that uses AES, should not be used for real hashing, hashes to AES_BYTES bytes void crypto::fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes) { diff --git a/src/util/crypto/crypto.h b/src/util/crypto/crypto.h index ace7cb8..d0c46df 100644 --- a/src/util/crypto/crypto.h +++ b/src/util/crypto/crypto.h @@ -62,6 +62,7 @@ public: //Hash routines void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes); + void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint8_t* tmpbuf); void hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint32_t ctr); void fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes); void fixed_key_aes_hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes); diff --git a/src/util/helpers.h b/src/util/helpers.h index 8895786..e71a476 100644 --- a/src/util/helpers.h +++ b/src/util/helpers.h @@ -31,12 +31,12 @@ struct element_ctx { bool hasvarbytelen; }; -struct hash_ctx { +struct sym_ctx { crypto* symcrypt; - + uint8_t* keydata; }; -struct encrypt_ctx { +struct asym_ctx { num* exponent; pk_crypto* field; bool sample; @@ -45,8 +45,8 @@ struct encrypt_ctx { struct task_ctx { element_ctx eles; union { - hash_ctx hctx; - encrypt_ctx ectx; + sym_ctx sctx; + asym_ctx actx; }; }; @@ -108,20 +108,20 @@ static void create_result_from_matches_fixed_bitlen(uint8_t** result, uint32_t i } } -static void *encrypt(void* context) { +static void *asym_encrypt(void* context) { #ifdef DEBUG cout << "Encryption task started" << endl; #endif - pk_crypto* field = ((task_ctx*) context)->ectx.field; + pk_crypto* field = ((task_ctx*) context)->actx.field; element_ctx electx = ((task_ctx*) context)->eles; - num* e = ((task_ctx*) context)->ectx.exponent; + num* e = ((task_ctx*) context)->actx.exponent; fe* tmpfe = field->get_fe(); uint8_t *inptr=electx.input1d, *outptr=electx.output; uint32_t i; for(i = 0; i < electx.nelements; i++, inptr+=electx.fixedbytelen, outptr+=electx.outbytelen) { - if(((task_ctx*) context)->ectx.sample) { + if(((task_ctx*) context)->actx.sample) { tmpfe->sample_fe_from_bytes(inptr, electx.fixedbytelen); //cout << "Mapped " << ((uint32_t*) inptr)[0] << " to "; } else { @@ -135,29 +135,71 @@ static void *encrypt(void* context) { return 0; } -static void *hash(void* context) { +static void *sym_encrypt(void* context) { #ifdef DEBUG cout << "Hashing thread started" << endl; #endif - hash_ctx hdata = ((task_ctx*) context)->hctx; + sym_ctx hdata = ((task_ctx*) context)->sctx; element_ctx electx = ((task_ctx*) context)->eles; crypto* crypt_env = hdata.symcrypt; + AES_KEY_CTX aes_key; + //cout << "initializing key" << endl; + crypt_env->init_aes_key(&aes_key, hdata.keydata); + //cout << "initialized key" << endl; + + uint8_t* aes_buf = (uint8_t*) malloc(AES_BYTES); uint32_t* perm = electx.perm; uint32_t i; if(electx.hasvarbytelen) { uint8_t **inptr = electx.input2d; for(i = electx.startelement; i < electx.endelement; i++) { - crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i]); + //crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i]); + //cout << "encrypting i = " << i << ", perm = " << perm [i] << ", outbytelen = " << electx.outbytelen << endl; + crypt_env->encrypt(&aes_key, aes_buf, inptr[i], electx.varbytelens[i]); + memcpy(electx.output+perm[i]*electx.outbytelen, aes_buf, electx.outbytelen); } } else { uint8_t *inptr = electx.input1d; for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) { - crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen); + //crypt_env->hash(&aes_key, electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen); + crypt_env->encrypt(&aes_key, aes_buf, inptr, electx.fixedbytelen); + memcpy(electx.output+perm[i]*electx.outbytelen, aes_buf, electx.outbytelen); } } + + //cout << "Returning" << endl; + //free(aes_buf); + return 0; +} + +static void *hash(void* context) { +#ifdef DEBUG + cout << "Hashing thread started" << endl; +#endif + sym_ctx hdata = ((task_ctx*) context)->sctx; + element_ctx electx = ((task_ctx*) context)->eles; + + crypto* crypt_env = hdata.symcrypt; + + uint32_t* perm = electx.perm; + uint32_t i; + uint8_t* tmphashbuf = (uint8_t*) malloc(crypt_env->get_hash_bytes()); + + if(electx.hasvarbytelen) { + uint8_t **inptr = electx.input2d; + for(i = electx.startelement; i < electx.endelement; i++) { + crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i], tmphashbuf); + } + } else { + uint8_t *inptr = electx.input1d; + for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) { + crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen, tmphashbuf); + } + } + free(tmphashbuf); return 0; } @@ -198,8 +240,6 @@ static void run_task(uint32_t nthreads, task_ctx context, void* (*func)(void*) ) neles_cur = min(context.eles.nelements - electr, neles_thread); memcpy(contexts + i, &context, sizeof(task_ctx)); contexts[i].eles.nelements = neles_cur; - //contexts[i].eles.input = context.eles.input + (context.eles.inbytelen * electr); - //contexts[i].eles.output = context.eles.output + (context.eles.outbytelen * electr); contexts[i].eles.startelement = electr; contexts[i].eles.endelement = electr + neles_cur; electr += neles_cur; @@ -228,7 +268,6 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas uint32_t hashbytelen, uint32_t* perm, uint32_t* matches) { uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles); - //uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles); uint64_t* tmpval; uint32_t size_intersect, i, intersect_ctr; @@ -236,17 +275,12 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas for(i = 0; i < neles; i++) { invperm[perm[i]] = i; } - //cout << "My number of elements. " << neles << ", partner number of elements: " << pneles << ", maskbytelen: " << hashbytelen << endl; GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL); for(i = 0; i < neles; i++) { g_hash_table_insert(map,(void*) ((uint64_t*) &(hashes[i*hashbytelen])), &(invperm[i])); } - //for(i = 0; i < pneles; i++) { - // ((uint64_t*) &(phashes[i*hashbytelen]))[0]++; - //} - for(i = 0, intersect_ctr = 0; i < pneles; i++) { if(g_hash_table_lookup_extended(map, (void*) ((uint64_t*) &(phashes[i*hashbytelen])), @@ -259,14 +293,7 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas size_intersect = intersect_ctr; - //result = (uint8_t**) malloc(sizeof(uint8_t*)); - //(*result) = (uint8_t*) malloc(sizeof(uint8_t) * size_intersect * elebytelen); - //for(i = 0; i < size_intersect; i++) { - // memcpy((*result) + i * elebytelen, elements + matches[i] * elebytelen, elebytelen); - //} - free(invperm); - //free(matches); return size_intersect; }