diff --git a/Makefile b/Makefile index 65cf983..20bef60 100644 --- a/Makefile +++ b/Makefile @@ -72,6 +72,12 @@ bench: demo: ${CC} -o demo.exe ${SRC}/mains/psi_demo.cpp ${OBJECTS_DHPSI} ${OBJECTS_OTPSI} ${OBJECTS_NAIVE} ${OBJECTS_SERVERAIDED} ${OBJECTS_UTIL} ${OBJECTS_HASHING} ${OBJECTS_CRYPTO} ${OBJECTS_OT} ${OBJECTS_MIRACL} ${CFLAGS} ${DEBUG_OPTIONS} ${LIBRARIES} ${MIRACL_LIB} ${INCLUDE} ${COMPILER_OPTIONS} +test: + ${CC} -o test.exe ${SRC}/mains/test_psi.cpp ${OBJECTS_DHPSI} ${OBJECTS_OTPSI} ${OBJECTS_NAIVE} ${OBJECTS_SERVERAIDED} ${OBJECTS_UTIL} ${OBJECTS_HASHING} ${OBJECTS_CRYPTO} ${OBJECTS_OT} ${OBJECTS_MIRACL} ${CFLAGS} ${DEBUG_OPTIONS} ${LIBRARIES} ${MIRACL_LIB} ${INCLUDE} ${COMPILER_OPTIONS} + ./test.exe -r 0 -t 10 & + ./test.exe -r 1 -t 10 + + cuckoo: ${CC} -o cuckoo.exe ${SRC}/mains/cuckoo_analysis.cpp ${OBJECTS_UTIL} ${OBJECTS_HASHING} ${OBJECTS_CRYPTO} ${OBJECTS_MIRACL} ${CFLAGS} ${DEBUG_OPTIONS} ${LIBRARIES} ${MIRACL_LIB} ${INCLUDE} ${COMPILER_OPTIONS} diff --git a/README.md b/README.md index e49b248..aef6297 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,25 @@ These commands will run the naive hashing protocol and compute the intersection For further information about the program options, run ```./demo.exe -h```. +### Testing the Protocols + +The protocols will automatically be tested on randomly generated data when invoking: +``` + make test +``` + +WARNING: Some tests can still fail since the code is currently being debugged. + +### Generating Random Email Adresses + +Further random email adresses can be generated by navigating to `/sample_sets/emailgenerator/` and invoking: + +``` + ./emailgenerator.py "number_of_emails" +``` + +The generator uses the first names, family names, and email providers listed in the corresponding files in `sample_sets/emailgenerator/` as base for the generation. + ### References [1] B. Pinkas, T. Schneider, M. Zohner. Faster Private Set Intersection Based on OT Extension. USENIX Security 2014: 797-812. Full version available at http://eprint.iacr.org/2014/447. diff --git a/src/hashing/cuckoo.cpp b/src/hashing/cuckoo.cpp index 948297e..ac6630d 100644 --- a/src/hashing/cuckoo.cpp +++ b/src/hashing/cuckoo.cpp @@ -165,7 +165,7 @@ inline void gen_cuckoo_entry(uint8_t* in, cuckoo_entry_ctx* out, hs_t* hs, uint3 out->eleid = ele_id; #ifndef TEST_UTILIZATION - out->val = (uint8_t*) malloc(hs->outbytelen); + out->val = (uint8_t*) calloc(hs->outbytelen, sizeof(uint8_t)); #endif hashElement(in, out->address, out->val, hs); } diff --git a/src/hashing/hashing_util.h b/src/hashing/hashing_util.h index bdad1ac..99223d1 100644 --- a/src/hashing/hashing_util.h +++ b/src/hashing/hashing_util.h @@ -56,6 +56,7 @@ static const uint32_t SELECT_BITS_INV[33] = \ 0xFF000000, 0xFE000000, 0xFC000000, 0xF8000000, 0xF0000000, 0xE0000000, 0xC0000000, 0x80000000, \ 0x00000000 }; +static const uint8_t BYTE_SELECT_BITS_INV[8] = {0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01}; //Init the values for the hash function static void init_hashing_state(hs_t* hs, uint32_t nelements, uint32_t inbitlen, uint32_t nbins, @@ -168,8 +169,26 @@ inline void hashElement(uint8_t* element, uint32_t* address, uint8_t* val, hs_t* *((uint32_t*) val) = R; //TODO copy remaining bits - if(hs->outbytelen >= sizeof(uint32_t)) - memcpy(val + (sizeof(uint32_t) - hs->addrbytelen), element + sizeof(uint32_t), hs->outbytelen - sizeof(uint32_t)); + //if(hs->outbytelen >= sizeof(uint32_t)) + if(hs->outbitlen + hs->addrbitlen >= sizeof(uint32_t) * 8) { + //memcpy(val + (sizeof(uint32_t) - hs->addrbytelen), element + sizeof(uint32_t), hs->outbytelen - (sizeof(uint32_t) - hs->addrbytelen)); + memcpy(val + (sizeof(uint32_t) - (hs->addrbitlen >>3)), element + sizeof(uint32_t), hs->outbytelen - (sizeof(uint32_t) - (hs->addrbitlen >>3))); + + //cout << "Element: "<< (hex) << (uint32_t) val[hs->outbytelen-1] << ", " << (uint32_t) (BYTE_SELECT_BITS_INV[hs->outbitlen & 0x03]) + // << ", " << (uint32_t) (val[hs->outbytelen-1] & (BYTE_SELECT_BITS_INV[hs->outbitlen & 0x03]) )<< (dec) << " :"; + + val[hs->outbytelen-1] &= (BYTE_SELECT_BITS_INV[hs->outbitlen & 0x03]); + + /*for(i = 0; i < hs->inbytelen; i++) { + cout << (hex) << (uint32_t) element[i]; + } + cout << ", "; + for(i = 0; i < hs->outbytelen; i++) { + cout << (hex) << (uint32_t) val[i]; + } + cout << (dec) << endl;*/ + } + #endif //cout << "Address for hfid = " << hfid << ": " << *address << ", L = " << L << ", R = " << R << endl; diff --git a/src/hashing/simple_hashing.cpp b/src/hashing/simple_hashing.cpp index 3d11f7c..55ce6c6 100644 --- a/src/hashing/simple_hashing.cpp +++ b/src/hashing/simple_hashing.cpp @@ -103,7 +103,7 @@ void *gen_entries(void *ctx_tmp) { uint32_t i, inbytelen, *address; address = (uint32_t*) malloc(NUM_HASH_FUNCTIONS * sizeof(uint32_t)); - tmpbuf = (uint8_t*) malloc(ceil_divide(ctx->hs->outbitlen, 8)); //for(i = 0; i < NUM_HASH_FUNCTIONS; i++) { + tmpbuf = (uint8_t*) calloc(ceil_divide(ctx->hs->outbitlen, 8), sizeof(uint8_t)); //for(i = 0; i < NUM_HASH_FUNCTIONS; i++) { // tmpbuf[i] = (uint8_t*) malloc(ceil_divide(ctx->hs->outbitlen, 8)); //} @@ -132,7 +132,10 @@ inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address, } tmp_bin->nvals++; //TODO: or simply allocate a bigger block of memory: table->maxbinsize * 2, left out for efficiency reasons - assert(tmp_bin->nvals < table->maxbinsize); + if(tmp_bin->nvals == table->maxbinsize) { + increase_max_bin_size(table, hs->outbytelen); + } + //assert(tmp_bin->nvals < table->maxbinsize); /*cout << "Inserted into bin: " << address << ": " << (hex); for(uint32_t j = 0; j < table->outbytelen; j++) { cout << (unsigned int) tmpbuf[j]; @@ -142,7 +145,6 @@ inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address, } } - void init_hash_table(sht_ctx* table, uint32_t nelements, hs_t* hs) { uint32_t i; @@ -176,3 +178,15 @@ void free_hash_table(sht_ctx* table) { //3. free the actual table //free(table); } + +void increase_max_bin_size(sht_ctx* table, uint32_t valbytelen) { + uint32_t new_maxsize = table->maxbinsize * 2; + uint8_t* tmpvals; + for(uint32_t i = 0; i < table->nbins; i++) { + tmpvals = table->bins[i].values; + table->bins[i].values = (uint8_t*) malloc(new_maxsize * valbytelen); + memcpy(table->bins[i].values, tmpvals, table->bins[i].nvals * valbytelen); + free(tmpvals); + } + table->maxbinsize = new_maxsize; +} diff --git a/src/hashing/simple_hashing.h b/src/hashing/simple_hashing.h index 7b7038e..d06e5d7 100644 --- a/src/hashing/simple_hashing.h +++ b/src/hashing/simple_hashing.h @@ -48,6 +48,7 @@ uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint //routine for generating the entries, is invoked by the threads void *gen_entries(void *ctx); void init_hash_table(sht_ctx* table, uint32_t nelements, hs_t* hs); +void increase_max_bin_size(sht_ctx* table, uint32_t valbytelen); void free_hash_table(sht_ctx* table); inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address, uint8_t* tmpbuf, hs_t* hs); diff --git a/src/mains/test_psi.cpp b/src/mains/test_psi.cpp index 0da42fd..6079b70 100644 --- a/src/mains/test_psi.cpp +++ b/src/mains/test_psi.cpp @@ -5,6 +5,243 @@ * Author: mzohner */ +#include "test_psi.h" + +int32_t main(int32_t argc, char** argv) { + string address="127.0.0.1"; + uint32_t nelements, elebytelen, ntasks=1, nruns=1, symsecbits=128; + uint64_t rnd; + role_type role = (role_type) 0; + vector sockfd(ntasks); + uint16_t port=7766; + uint8_t* seed = (uint8_t*) malloc(AES_BYTES); + + + read_psi_test_options(&argc, &argv, &role, &nruns); + + memcpy(seed, const_seed, AES_BYTES); + seed[0] = role; + crypto* crypt = new crypto(symsecbits, seed); + + crypt->gen_rnd((uint8_t*) &rnd, sizeof(uint64_t)); + srand((unsigned)rnd+time(0)); + + if(role == SERVER) { + listen(address.c_str(), port, sockfd.data(), ntasks); + } else { + for(uint32_t i = 0; i < ntasks; i++) + connect(address.c_str(), port, sockfd[i]); + } + + for(uint32_t i = 0; i < nruns; i++) { + if(role == CLIENT) cout << "Running test on iteration " << i << std::flush; + nelements = rand() % (1<<12); + elebytelen = (rand() % 14) + 2; + test_psi_prot(role, sockfd.data(), nelements, elebytelen, crypt); + if(role == CLIENT) cout << endl; + } + + cout << "All tests successfully passed" << endl; +} + + +uint32_t test_psi_prot(role_type role, CSocket* sock, uint32_t nelements, + uint32_t elebytelen, crypto* crypt) { + double epsilon=1.2; + uint32_t p_inter_size, n_inter_size, ot_inter_size, dh_inter_size, i, j, ntasks=1, + pnelements, nclients = 2; + uint8_t *elements, *pelements, *p_intersection, *n_intersection, *ot_intersection, *dh_intersection; + + //if(protocol != TTP) + + pnelements = set_up_parameters(role, nelements, &elebytelen, &elements, &pelements, sock[0], crypt); + + p_inter_size = plaintext_intersect(nelements, pnelements, elebytelen, elements, pelements, + &p_intersection); + //cout << "Plaintext intersection computed " << endl; + if(role == CLIENT) cout << "." << std::flush; + + n_inter_size = naivepsi(role, nelements, pnelements, elebytelen, elements, &n_intersection, crypt, + sock, ntasks); + //cout << "Naive intersection computed " << endl; + if(role == CLIENT) cout << "." << std::flush; + + dh_inter_size = dhpsi(role, nelements, pnelements, elebytelen, elements, &dh_intersection, crypt, + sock, ntasks); + //cout << "DH intersection computed " << endl; + if(role == CLIENT) cout << "." << std::flush; + + ot_inter_size = otpsi(role, nelements, pnelements, elebytelen, elements, &ot_intersection, + crypt, sock, ntasks, epsilon); + //cout << "OT intersection computed " << endl; + if(role == CLIENT) cout << "." << std::flush; + + + if(role == CLIENT) { + bool success = true; + success &= (p_inter_size == n_inter_size); + success &= (p_inter_size == dh_inter_size); + success &= (p_inter_size == ot_inter_size); + + for(uint32_t i = 0; i < p_inter_size * elebytelen; i++) { + success &= (p_intersection[i] == n_intersection[i]); + success &= (p_intersection[i] == dh_intersection[i]); + success &= (p_intersection[i] == ot_intersection[i]); + } + + if(!success) { + cout << "Error in tests for " << nelements << " and " << pnelements << " on " << elebytelen + << " byte length: " << endl; + + cout << "\t" << p_inter_size << " elements in verification intersection" << endl; + cout << "\t" << n_inter_size << " elements in naive intersection" << endl; + cout << "\t" << dh_inter_size << " elements in DH intersection" << endl; + cout << "\t" << ot_inter_size << " elements in OT intersection" << endl; + + cout << "Plaintext intersection (" << p_inter_size << "): " << endl; + plot_set(p_intersection, p_inter_size, elebytelen); + cout << "Naive intersection (" << n_inter_size << "): " << endl; + plot_set(n_intersection, n_inter_size, elebytelen); + cout << "DH intersection (" << dh_inter_size << "): " << endl; + plot_set(dh_intersection, dh_inter_size, elebytelen); + cout << "OT intersection: (" << ot_inter_size << "): " << endl; + plot_set(ot_intersection, ot_inter_size, elebytelen); + } + + if(p_inter_size > 0) + free(p_intersection); + if(n_inter_size > 0) + free(n_intersection); + if(dh_inter_size > 0) + free(dh_intersection); + if(ot_inter_size > 0) + free(ot_intersection); + + assert(success); + } + + + free(elements); + free(pelements); + + + return 1; +} + +void plot_set(uint8_t* set, uint32_t neles, uint32_t elebytelen) { + for(uint32_t i = 0; i < neles; i++) { + cout << i << ": "; + for(uint32_t j = 0; j < elebytelen; j++) { + cout << setw(2) << setfill('0') << (hex) << (uint32_t) set[i*elebytelen+j]; + } + cout << (dec) << endl; + } +} + + +uint32_t plaintext_intersect(uint32_t myneles, uint32_t pneles, uint32_t bytelen, uint8_t* myelements, + uint8_t* pelements, uint8_t** result) { + uint32_t intersect_size = 0, i, j; + uint64_t tmpkey = 0; + uint8_t *tmpval; + uint8_t** matches = (uint8_t**) malloc(sizeof(uint8_t*) * min(myneles, pneles)); + uint32_t keylen = min((uint32_t) bytelen, (uint32_t) 8); + bool success; + + + GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL); + for(i = 0; i < myneles; i++) { + memcpy(&tmpkey, myelements+i*bytelen, keylen); + g_hash_table_insert(map,(void*) &tmpkey, myelements+i*bytelen); + } + + for(i = 0; i < pneles; i++) { + memcpy(&tmpkey, pelements+i*bytelen, keylen); + if(g_hash_table_lookup_extended(map, (void*) &tmpkey, NULL, (void**) &tmpval)) { + success = true; + if(bytelen > 8) { + for(j = 8; j < bytelen && success; j++) { + if(tmpval[j] != pelements[i*bytelen+j]) + success = false; + } + } + if(success) { + matches[intersect_size] = (uint8_t*) tmpval; + intersect_size++; + } + + assert(intersect_size <= min(myneles, pneles)); + } + } + + *result = (uint8_t*) malloc(intersect_size * bytelen); + + for(i = 0; i < intersect_size; i++) { + memcpy((*result) + i * bytelen, matches[i], bytelen); + } + + free(matches); + return intersect_size; +} + + +uint32_t set_up_parameters(role_type role, uint32_t myneles, uint32_t* mybytelen, + uint8_t** elements, uint8_t** pelements, CSocket& sock, crypto* crypt) { + + uint32_t pneles, nintersections; + + //Exchange meta-information and equalize byte-length + sock.Send(&myneles, sizeof(uint32_t)); + sock.Receive(&pneles, sizeof(uint32_t)); + + if(role == SERVER) { + sock.Send(mybytelen, sizeof(uint32_t)); + } else { + sock.Receive(mybytelen, sizeof(uint32_t)); + } + *elements = (uint8_t*) malloc(myneles * *mybytelen); + *pelements = (uint8_t*) malloc(pneles * *mybytelen); + + crypt->gen_rnd(*elements, myneles * *mybytelen); + + //Exchange elements for later check + if(role == SERVER) { + sock.Send(*elements, myneles * *mybytelen); + sock.Receive(*pelements, pneles * *mybytelen); + } else { //have the client use some of the servers values s.t. the intersection is not disjoint + sock.Receive(*pelements, pneles * *mybytelen); + nintersections = rand() % min(myneles, pneles); + for(uint32_t i = 0; i < nintersections; i++) { + memcpy(*elements + i * *mybytelen, *pelements + i * *mybytelen, *mybytelen); + } + sock.Send(*elements, myneles * *mybytelen); + } + + + + //memset(*elements, 0x00, *mybytelen); + //memset(*pelements, 0x00, *mybytelen); + + return pneles; +} + + +int32_t read_psi_test_options(int32_t* argcp, char*** argvp, role_type* role, uint32_t* nruns) { + uint32_t int_role; + parsing_ctx options[] = {{(void*) &int_role, T_NUM, 'r', "Role: 0/1", true, false}, + {(void*) nruns, T_NUM, 't', "#of test iterations", false, false}, + }; + + if(!parse_options(argcp, argvp, options, sizeof(options)/sizeof(parsing_ctx))) { + print_usage(argvp[0][0], options, sizeof(options)/sizeof(parsing_ctx)); + exit(0); + } + + assert(int_role < 2); + *role = (role_type) int_role; + + return 1; +} diff --git a/src/mains/test_psi.h b/src/mains/test_psi.h index 4633fdc..4ae7b9e 100644 --- a/src/mains/test_psi.h +++ b/src/mains/test_psi.h @@ -8,6 +8,28 @@ #ifndef TEST_PSI_H_ #define TEST_PSI_H_ +#define SILENT_TESTS +#include +#include +#include +#include +#include +#include "../pk-based/dh-psi.h" +#include "../ot-based/ot-psi.h" +#include "../server-aided/sapsi.h" +#include "../naive-hashing/naive-psi.h" +#include "../util/parse_options.h" +#include "../util/helpers.h" + + +uint32_t test_psi_prot(role_type role, CSocket* sock, uint32_t nelements, + uint32_t elebytelen, crypto* crypt); +uint32_t plaintext_intersect(uint32_t myneles, uint32_t pneles, uint32_t bytelen, uint8_t* myelements, + uint8_t* pelements, uint8_t** result); +uint32_t set_up_parameters(role_type role, uint32_t myneles, uint32_t* mybytelen, + uint8_t** elements, uint8_t** pelements, CSocket& sock, crypto* crypt); +int32_t read_psi_test_options(int32_t* argcp, char*** argvp, role_type* role, uint32_t* nruns); +void plot_set(uint8_t* set, uint32_t neles, uint32_t elebytelen); #endif /* TEST_PSI_H_ */ diff --git a/src/ot-based/ot-psi.cpp b/src/ot-based/ot-psi.cpp index b001348..3e6321c 100644 --- a/src/ot-based/ot-psi.cpp +++ b/src/ot-based/ot-psi.cpp @@ -52,15 +52,6 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t* elebyt intersect_size = otpsi_client(eleptr, neles, nbins, pneles, internal_bitlen, maskbitlen, crypt_env, sock, ntasks, &prf_state, &res_pos); - //std::sort(res_pos, res_pos+intersect_size); - - //*result = (uint8_t**) malloc(intersect_size * sizeof(uint8_t*)); - //*res_bytelen = (uint32_t*) malloc(intersect_size * sizeof(uint32_t)); - /*for(i = 0; i < intersect_size; i++) { - (*res_bytelen)[i] = elebytelens[res_pos[i]]; - (*result)[i] = (uint8_t*) malloc((*res_bytelen)[i]); - memcpy((*result)[i], elements[res_pos[i]], (*res_bytelen)[i]); - }*/ create_result_from_matches_var_bitlen(result, res_bytelen, elebytelens, elements, res_pos, intersect_size); } @@ -71,12 +62,12 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t* elebyt -uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebitlen, uint8_t* elements, +uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebytelen, uint8_t* elements, uint8_t** result, crypto* crypt_env, CSocket* sock, uint32_t ntasks, double epsilon, bool detailed_timings) { prf_state_ctx prf_state; - uint32_t maskbytelen, nbins, intersect_size, internal_bitlen, maskbitlen, *res_pos, i, elebytelen; + uint32_t maskbytelen, nbins, intersect_size, internal_bitlen, maskbitlen, *res_pos, i, elebitlen; uint8_t *eleptr; timeval t_start, t_end; @@ -84,16 +75,16 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebitl maskbitlen = pad_to_multiple(crypt_env->get_seclvl().statbits + ceil_log2(neles) + ceil_log2(pneles), 8); maskbytelen = ceil_divide(maskbitlen, 8); - elebytelen = ceil_divide(elebitlen, 8); + elebitlen = elebytelen * 8; if(elebitlen > maskbitlen) { //Hash elements into a smaller domain eleptr = (uint8_t*) malloc(maskbytelen * neles); - domain_hashing(neles, elements, ceil_divide(elebitlen, 8), eleptr, maskbytelen, crypt_env); + domain_hashing(neles, elements, elebytelen, eleptr, maskbytelen, crypt_env); internal_bitlen = maskbitlen; #ifndef BATCH cout << "Hashing " << neles << " elements with " << elebitlen << " bit-length into " << - maskbitlen << " representation " << endl; + maskbitlen << " bit representation " << endl; #endif } else { eleptr = elements; @@ -110,10 +101,11 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebitl nbins = ceil(epsilon * neles); intersect_size = otpsi_client(eleptr, neles, nbins, pneles, internal_bitlen, maskbitlen, crypt_env, sock, ntasks, &prf_state, &res_pos); - *result = (uint8_t*) malloc(intersect_size * elebytelen); - for(i = 0; i < intersect_size; i++) { - memcpy((*result) + i * elebytelen, elements + res_pos[i] * elebytelen, elebytelen); - } + //*result = (uint8_t*) malloc(intersect_size * elebytelen); + //for(i = 0; i < intersect_size; i++) { + // memcpy((*result) + i * elebytelen, elements + res_pos[i] * elebytelen, elebytelen); + //} + create_result_from_matches_fixed_bitlen(result, elebytelen, elements, res_pos, intersect_size); } if(elebitlen > maskbitlen) @@ -176,27 +168,23 @@ uint32_t otpsi_client(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_ masks = (uint8_t*) malloc(neles * maskbytelen); //Perform the OPRG execution + //cout << "otpsi client running ots" << endl; oprg_client(hash_table, nbins, neles, nelesinbin, outbitlen, maskbitlen, crypt_env, sock, ntasks, masks); if(DETAILED_TIMINGS) { gettimeofday(&t_start, NULL); } - /*uint64_t tmpmask = 0; - for(uint32_t i = 0; i < neles-1; i++) { - memcpy((uint8_t*) &tmpmask, masks + i*maskbytelen, maskbytelen); - cout << "Mask " << i << " : " << (hex) << tmpmask << (dec) << endl; //"intersection found at position " << tmpval[0] << " for key " << tmpbuf[0] << endl; - }*/ #ifdef TIMING gettimeofday(&t_end, NULL); cout << "Client: time for OPRG evaluation: " << getMillies(t_start, t_end) << " ms" << endl; gettimeofday(&t_start, NULL); #endif -/*#ifdef PRINT_BIN_CONTENT +#ifdef PRINT_BIN_CONTENT cout << "Client masks: " << endl; print_bin_content(masks, neles, maskbytelen, NULL, false); -#endif*/ +#endif //receive server masks server_masks = (uint8_t*) malloc(NUM_HASH_FUNCTIONS * pneles * maskbytelen); @@ -265,6 +253,7 @@ uint32_t otpsi_client(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_ cout << "Time for intersecting:\t\t" << fixed << std::setprecision(2) << getMillies(t_start, t_end) << " ms" << endl; } + free(masks); free(hash_table); free(nelesinbin); @@ -720,7 +709,7 @@ uint32_t otpsi_find_intersection(uint32_t** result, uint8_t* my_hashes, } else { keys_stored = 1; tmp_hashbytelen = hashbytelen; - tmpkeys = (uint32_t*) malloc(my_neles * keys_stored * sizeof(uint32_t)); + tmpkeys = (uint32_t*) malloc(my_neles * sizeof(uint32_t)); memcpy(tmpkeys, perm, my_neles * sizeof(uint32_t)); } @@ -760,7 +749,6 @@ uint32_t otpsi_find_intersection(uint32_t** result, uint8_t* my_hashes, } } - //TODO: workaround since the masks that are inserted into the hash table are too small and collisions occur if(intersect_ctr > my_neles) { cout << "more intersections than elements: " << intersect_ctr << " vs " << my_neles << endl; intersect_ctr = my_neles; @@ -775,8 +763,8 @@ uint32_t otpsi_find_intersection(uint32_t** result, uint8_t* my_hashes, //cout << "I found " << size_intersect << " intersecting elements" << endl; free(matches); - free(map); free(invperm); + free(tmpkeys); return size_intersect; } diff --git a/src/util/helpers.h b/src/util/helpers.h index e71a476..247c575 100644 --- a/src/util/helpers.h +++ b/src/util/helpers.h @@ -91,6 +91,8 @@ static void create_result_from_matches_var_bitlen(uint8_t*** result, uint32_t** *result = (uint8_t**) malloc(sizeof(uint8_t*) * intersect_size); *resbytelens = (uint32_t*) malloc(sizeof(uint32_t) * intersect_size); + std::sort(matches, matches+intersect_size); + for(i = 0; i < intersect_size; i++) { (*resbytelens)[i] = inbytelens[matches[i]]; (*result)[i] = (uint8_t*) malloc((*resbytelens)[i]); @@ -101,10 +103,12 @@ static void create_result_from_matches_var_bitlen(uint8_t*** result, uint32_t** static void create_result_from_matches_fixed_bitlen(uint8_t** result, uint32_t inbytelen, uint8_t* inputs, uint32_t* matches, uint32_t intersect_size) { uint32_t i; - *result = (uint8_t*) malloc(sizeof(uint8_t) * intersect_size); + *result = (uint8_t*) malloc(inbytelen * intersect_size); + + std::sort(matches, matches+intersect_size); for(i = 0; i < intersect_size; i++) { - memcpy(result + i * inbytelen, inputs + matches[i] * inbytelen, inbytelen); + memcpy(*(result) + i * inbytelen, inputs + matches[i] * inbytelen, inbytelen); } } @@ -268,8 +272,8 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas uint32_t hashbytelen, uint32_t* perm, uint32_t* matches) { uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles); - uint64_t* tmpval; - + uint64_t *tmpval, tmpkey = 0; + uint32_t mapbytelen = min((uint32_t) hashbytelen, (uint32_t) sizeof(uint64_t)); uint32_t size_intersect, i, intersect_ctr; for(i = 0; i < neles; i++) { @@ -278,13 +282,13 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL); for(i = 0; i < neles; i++) { - g_hash_table_insert(map,(void*) ((uint64_t*) &(hashes[i*hashbytelen])), &(invperm[i])); + memcpy(&tmpkey, hashes + i*hashbytelen, mapbytelen); + g_hash_table_insert(map,(void*) &tmpkey, &(invperm[i])); } for(i = 0, intersect_ctr = 0; i < pneles; i++) { - - if(g_hash_table_lookup_extended(map, (void*) ((uint64_t*) &(phashes[i*hashbytelen])), - NULL, (void**) &tmpval)) { + memcpy(&tmpkey, phashes+ i*hashbytelen, mapbytelen); + if(g_hash_table_lookup_extended(map, (void*) &tmpkey, NULL, (void**) &tmpval)) { matches[intersect_ctr] = tmpval[0]; intersect_ctr++; assert(intersect_ctr <= min(neles, pneles)); diff --git a/src/util/typedefs.h b/src/util/typedefs.h index 64d0747..a0bc0a9 100644 --- a/src/util/typedefs.h +++ b/src/util/typedefs.h @@ -9,7 +9,7 @@ #define TYPEDEFS_H_ //#define DEBUG -//#define BATCH +#define BATCH //#define TIMING #include