Added test routines

This commit is contained in:
Michael Zohner 2015-06-09 17:11:57 +02:00
parent 849f790167
commit 70d49f1871
11 changed files with 353 additions and 43 deletions

View File

@ -72,6 +72,12 @@ bench:
demo:
${CC} -o demo.exe ${SRC}/mains/psi_demo.cpp ${OBJECTS_DHPSI} ${OBJECTS_OTPSI} ${OBJECTS_NAIVE} ${OBJECTS_SERVERAIDED} ${OBJECTS_UTIL} ${OBJECTS_HASHING} ${OBJECTS_CRYPTO} ${OBJECTS_OT} ${OBJECTS_MIRACL} ${CFLAGS} ${DEBUG_OPTIONS} ${LIBRARIES} ${MIRACL_LIB} ${INCLUDE} ${COMPILER_OPTIONS}
test:
${CC} -o test.exe ${SRC}/mains/test_psi.cpp ${OBJECTS_DHPSI} ${OBJECTS_OTPSI} ${OBJECTS_NAIVE} ${OBJECTS_SERVERAIDED} ${OBJECTS_UTIL} ${OBJECTS_HASHING} ${OBJECTS_CRYPTO} ${OBJECTS_OT} ${OBJECTS_MIRACL} ${CFLAGS} ${DEBUG_OPTIONS} ${LIBRARIES} ${MIRACL_LIB} ${INCLUDE} ${COMPILER_OPTIONS}
./test.exe -r 0 -t 10 &
./test.exe -r 1 -t 10
cuckoo:
${CC} -o cuckoo.exe ${SRC}/mains/cuckoo_analysis.cpp ${OBJECTS_UTIL} ${OBJECTS_HASHING} ${OBJECTS_CRYPTO} ${OBJECTS_MIRACL} ${CFLAGS} ${DEBUG_OPTIONS} ${LIBRARIES} ${MIRACL_LIB} ${INCLUDE} ${COMPILER_OPTIONS}

View File

@ -68,6 +68,25 @@ These commands will run the naive hashing protocol and compute the intersection
For further information about the program options, run ```./demo.exe -h```.
### Testing the Protocols
The protocols will automatically be tested on randomly generated data when invoking:
```
make test
```
WARNING: Some tests can still fail since the code is currently being debugged.
### Generating Random Email Adresses
Further random email adresses can be generated by navigating to `/sample_sets/emailgenerator/` and invoking:
```
./emailgenerator.py "number_of_emails"
```
The generator uses the first names, family names, and email providers listed in the corresponding files in `sample_sets/emailgenerator/` as base for the generation.
### References
[1] B. Pinkas, T. Schneider, M. Zohner. Faster Private Set Intersection Based on OT Extension. USENIX Security 2014: 797-812. Full version available at http://eprint.iacr.org/2014/447.

View File

@ -165,7 +165,7 @@ inline void gen_cuckoo_entry(uint8_t* in, cuckoo_entry_ctx* out, hs_t* hs, uint3
out->eleid = ele_id;
#ifndef TEST_UTILIZATION
out->val = (uint8_t*) malloc(hs->outbytelen);
out->val = (uint8_t*) calloc(hs->outbytelen, sizeof(uint8_t));
#endif
hashElement(in, out->address, out->val, hs);
}

View File

@ -56,6 +56,7 @@ static const uint32_t SELECT_BITS_INV[33] = \
0xFF000000, 0xFE000000, 0xFC000000, 0xF8000000, 0xF0000000, 0xE0000000, 0xC0000000, 0x80000000, \
0x00000000 };
static const uint8_t BYTE_SELECT_BITS_INV[8] = {0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01};
//Init the values for the hash function
static void init_hashing_state(hs_t* hs, uint32_t nelements, uint32_t inbitlen, uint32_t nbins,
@ -168,8 +169,26 @@ inline void hashElement(uint8_t* element, uint32_t* address, uint8_t* val, hs_t*
*((uint32_t*) val) = R;
//TODO copy remaining bits
if(hs->outbytelen >= sizeof(uint32_t))
memcpy(val + (sizeof(uint32_t) - hs->addrbytelen), element + sizeof(uint32_t), hs->outbytelen - sizeof(uint32_t));
//if(hs->outbytelen >= sizeof(uint32_t))
if(hs->outbitlen + hs->addrbitlen >= sizeof(uint32_t) * 8) {
//memcpy(val + (sizeof(uint32_t) - hs->addrbytelen), element + sizeof(uint32_t), hs->outbytelen - (sizeof(uint32_t) - hs->addrbytelen));
memcpy(val + (sizeof(uint32_t) - (hs->addrbitlen >>3)), element + sizeof(uint32_t), hs->outbytelen - (sizeof(uint32_t) - (hs->addrbitlen >>3)));
//cout << "Element: "<< (hex) << (uint32_t) val[hs->outbytelen-1] << ", " << (uint32_t) (BYTE_SELECT_BITS_INV[hs->outbitlen & 0x03])
// << ", " << (uint32_t) (val[hs->outbytelen-1] & (BYTE_SELECT_BITS_INV[hs->outbitlen & 0x03]) )<< (dec) << " :";
val[hs->outbytelen-1] &= (BYTE_SELECT_BITS_INV[hs->outbitlen & 0x03]);
/*for(i = 0; i < hs->inbytelen; i++) {
cout << (hex) << (uint32_t) element[i];
}
cout << ", ";
for(i = 0; i < hs->outbytelen; i++) {
cout << (hex) << (uint32_t) val[i];
}
cout << (dec) << endl;*/
}
#endif
//cout << "Address for hfid = " << hfid << ": " << *address << ", L = " << L << ", R = " << R << endl;

View File

@ -103,7 +103,7 @@ void *gen_entries(void *ctx_tmp) {
uint32_t i, inbytelen, *address;
address = (uint32_t*) malloc(NUM_HASH_FUNCTIONS * sizeof(uint32_t));
tmpbuf = (uint8_t*) malloc(ceil_divide(ctx->hs->outbitlen, 8)); //for(i = 0; i < NUM_HASH_FUNCTIONS; i++) {
tmpbuf = (uint8_t*) calloc(ceil_divide(ctx->hs->outbitlen, 8), sizeof(uint8_t)); //for(i = 0; i < NUM_HASH_FUNCTIONS; i++) {
// tmpbuf[i] = (uint8_t*) malloc(ceil_divide(ctx->hs->outbitlen, 8));
//}
@ -132,7 +132,10 @@ inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address,
}
tmp_bin->nvals++;
//TODO: or simply allocate a bigger block of memory: table->maxbinsize * 2, left out for efficiency reasons
assert(tmp_bin->nvals < table->maxbinsize);
if(tmp_bin->nvals == table->maxbinsize) {
increase_max_bin_size(table, hs->outbytelen);
}
//assert(tmp_bin->nvals < table->maxbinsize);
/*cout << "Inserted into bin: " << address << ": " << (hex);
for(uint32_t j = 0; j < table->outbytelen; j++) {
cout << (unsigned int) tmpbuf[j];
@ -142,7 +145,6 @@ inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address,
}
}
void init_hash_table(sht_ctx* table, uint32_t nelements, hs_t* hs) {
uint32_t i;
@ -176,3 +178,15 @@ void free_hash_table(sht_ctx* table) {
//3. free the actual table
//free(table);
}
void increase_max_bin_size(sht_ctx* table, uint32_t valbytelen) {
uint32_t new_maxsize = table->maxbinsize * 2;
uint8_t* tmpvals;
for(uint32_t i = 0; i < table->nbins; i++) {
tmpvals = table->bins[i].values;
table->bins[i].values = (uint8_t*) malloc(new_maxsize * valbytelen);
memcpy(table->bins[i].values, tmpvals, table->bins[i].nvals * valbytelen);
free(tmpvals);
}
table->maxbinsize = new_maxsize;
}

View File

@ -48,6 +48,7 @@ uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint
//routine for generating the entries, is invoked by the threads
void *gen_entries(void *ctx);
void init_hash_table(sht_ctx* table, uint32_t nelements, hs_t* hs);
void increase_max_bin_size(sht_ctx* table, uint32_t valbytelen);
void free_hash_table(sht_ctx* table);
inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address, uint8_t* tmpbuf, hs_t* hs);

View File

@ -5,6 +5,243 @@
* Author: mzohner
*/
#include "test_psi.h"
int32_t main(int32_t argc, char** argv) {
string address="127.0.0.1";
uint32_t nelements, elebytelen, ntasks=1, nruns=1, symsecbits=128;
uint64_t rnd;
role_type role = (role_type) 0;
vector<CSocket> sockfd(ntasks);
uint16_t port=7766;
uint8_t* seed = (uint8_t*) malloc(AES_BYTES);
read_psi_test_options(&argc, &argv, &role, &nruns);
memcpy(seed, const_seed, AES_BYTES);
seed[0] = role;
crypto* crypt = new crypto(symsecbits, seed);
crypt->gen_rnd((uint8_t*) &rnd, sizeof(uint64_t));
srand((unsigned)rnd+time(0));
if(role == SERVER) {
listen(address.c_str(), port, sockfd.data(), ntasks);
} else {
for(uint32_t i = 0; i < ntasks; i++)
connect(address.c_str(), port, sockfd[i]);
}
for(uint32_t i = 0; i < nruns; i++) {
if(role == CLIENT) cout << "Running test on iteration " << i << std::flush;
nelements = rand() % (1<<12);
elebytelen = (rand() % 14) + 2;
test_psi_prot(role, sockfd.data(), nelements, elebytelen, crypt);
if(role == CLIENT) cout << endl;
}
cout << "All tests successfully passed" << endl;
}
uint32_t test_psi_prot(role_type role, CSocket* sock, uint32_t nelements,
uint32_t elebytelen, crypto* crypt) {
double epsilon=1.2;
uint32_t p_inter_size, n_inter_size, ot_inter_size, dh_inter_size, i, j, ntasks=1,
pnelements, nclients = 2;
uint8_t *elements, *pelements, *p_intersection, *n_intersection, *ot_intersection, *dh_intersection;
//if(protocol != TTP)
pnelements = set_up_parameters(role, nelements, &elebytelen, &elements, &pelements, sock[0], crypt);
p_inter_size = plaintext_intersect(nelements, pnelements, elebytelen, elements, pelements,
&p_intersection);
//cout << "Plaintext intersection computed " << endl;
if(role == CLIENT) cout << "." << std::flush;
n_inter_size = naivepsi(role, nelements, pnelements, elebytelen, elements, &n_intersection, crypt,
sock, ntasks);
//cout << "Naive intersection computed " << endl;
if(role == CLIENT) cout << "." << std::flush;
dh_inter_size = dhpsi(role, nelements, pnelements, elebytelen, elements, &dh_intersection, crypt,
sock, ntasks);
//cout << "DH intersection computed " << endl;
if(role == CLIENT) cout << "." << std::flush;
ot_inter_size = otpsi(role, nelements, pnelements, elebytelen, elements, &ot_intersection,
crypt, sock, ntasks, epsilon);
//cout << "OT intersection computed " << endl;
if(role == CLIENT) cout << "." << std::flush;
if(role == CLIENT) {
bool success = true;
success &= (p_inter_size == n_inter_size);
success &= (p_inter_size == dh_inter_size);
success &= (p_inter_size == ot_inter_size);
for(uint32_t i = 0; i < p_inter_size * elebytelen; i++) {
success &= (p_intersection[i] == n_intersection[i]);
success &= (p_intersection[i] == dh_intersection[i]);
success &= (p_intersection[i] == ot_intersection[i]);
}
if(!success) {
cout << "Error in tests for " << nelements << " and " << pnelements << " on " << elebytelen
<< " byte length: " << endl;
cout << "\t" << p_inter_size << " elements in verification intersection" << endl;
cout << "\t" << n_inter_size << " elements in naive intersection" << endl;
cout << "\t" << dh_inter_size << " elements in DH intersection" << endl;
cout << "\t" << ot_inter_size << " elements in OT intersection" << endl;
cout << "Plaintext intersection (" << p_inter_size << "): " << endl;
plot_set(p_intersection, p_inter_size, elebytelen);
cout << "Naive intersection (" << n_inter_size << "): " << endl;
plot_set(n_intersection, n_inter_size, elebytelen);
cout << "DH intersection (" << dh_inter_size << "): " << endl;
plot_set(dh_intersection, dh_inter_size, elebytelen);
cout << "OT intersection: (" << ot_inter_size << "): " << endl;
plot_set(ot_intersection, ot_inter_size, elebytelen);
}
if(p_inter_size > 0)
free(p_intersection);
if(n_inter_size > 0)
free(n_intersection);
if(dh_inter_size > 0)
free(dh_intersection);
if(ot_inter_size > 0)
free(ot_intersection);
assert(success);
}
free(elements);
free(pelements);
return 1;
}
void plot_set(uint8_t* set, uint32_t neles, uint32_t elebytelen) {
for(uint32_t i = 0; i < neles; i++) {
cout << i << ": ";
for(uint32_t j = 0; j < elebytelen; j++) {
cout << setw(2) << setfill('0') << (hex) << (uint32_t) set[i*elebytelen+j];
}
cout << (dec) << endl;
}
}
uint32_t plaintext_intersect(uint32_t myneles, uint32_t pneles, uint32_t bytelen, uint8_t* myelements,
uint8_t* pelements, uint8_t** result) {
uint32_t intersect_size = 0, i, j;
uint64_t tmpkey = 0;
uint8_t *tmpval;
uint8_t** matches = (uint8_t**) malloc(sizeof(uint8_t*) * min(myneles, pneles));
uint32_t keylen = min((uint32_t) bytelen, (uint32_t) 8);
bool success;
GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL);
for(i = 0; i < myneles; i++) {
memcpy(&tmpkey, myelements+i*bytelen, keylen);
g_hash_table_insert(map,(void*) &tmpkey, myelements+i*bytelen);
}
for(i = 0; i < pneles; i++) {
memcpy(&tmpkey, pelements+i*bytelen, keylen);
if(g_hash_table_lookup_extended(map, (void*) &tmpkey, NULL, (void**) &tmpval)) {
success = true;
if(bytelen > 8) {
for(j = 8; j < bytelen && success; j++) {
if(tmpval[j] != pelements[i*bytelen+j])
success = false;
}
}
if(success) {
matches[intersect_size] = (uint8_t*) tmpval;
intersect_size++;
}
assert(intersect_size <= min(myneles, pneles));
}
}
*result = (uint8_t*) malloc(intersect_size * bytelen);
for(i = 0; i < intersect_size; i++) {
memcpy((*result) + i * bytelen, matches[i], bytelen);
}
free(matches);
return intersect_size;
}
uint32_t set_up_parameters(role_type role, uint32_t myneles, uint32_t* mybytelen,
uint8_t** elements, uint8_t** pelements, CSocket& sock, crypto* crypt) {
uint32_t pneles, nintersections;
//Exchange meta-information and equalize byte-length
sock.Send(&myneles, sizeof(uint32_t));
sock.Receive(&pneles, sizeof(uint32_t));
if(role == SERVER) {
sock.Send(mybytelen, sizeof(uint32_t));
} else {
sock.Receive(mybytelen, sizeof(uint32_t));
}
*elements = (uint8_t*) malloc(myneles * *mybytelen);
*pelements = (uint8_t*) malloc(pneles * *mybytelen);
crypt->gen_rnd(*elements, myneles * *mybytelen);
//Exchange elements for later check
if(role == SERVER) {
sock.Send(*elements, myneles * *mybytelen);
sock.Receive(*pelements, pneles * *mybytelen);
} else { //have the client use some of the servers values s.t. the intersection is not disjoint
sock.Receive(*pelements, pneles * *mybytelen);
nintersections = rand() % min(myneles, pneles);
for(uint32_t i = 0; i < nintersections; i++) {
memcpy(*elements + i * *mybytelen, *pelements + i * *mybytelen, *mybytelen);
}
sock.Send(*elements, myneles * *mybytelen);
}
//memset(*elements, 0x00, *mybytelen);
//memset(*pelements, 0x00, *mybytelen);
return pneles;
}
int32_t read_psi_test_options(int32_t* argcp, char*** argvp, role_type* role, uint32_t* nruns) {
uint32_t int_role;
parsing_ctx options[] = {{(void*) &int_role, T_NUM, 'r', "Role: 0/1", true, false},
{(void*) nruns, T_NUM, 't', "#of test iterations", false, false},
};
if(!parse_options(argcp, argvp, options, sizeof(options)/sizeof(parsing_ctx))) {
print_usage(argvp[0][0], options, sizeof(options)/sizeof(parsing_ctx));
exit(0);
}
assert(int_role < 2);
*role = (role_type) int_role;
return 1;
}

View File

@ -8,6 +8,28 @@
#ifndef TEST_PSI_H_
#define TEST_PSI_H_
#define SILENT_TESTS
#include <ctime>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include "../pk-based/dh-psi.h"
#include "../ot-based/ot-psi.h"
#include "../server-aided/sapsi.h"
#include "../naive-hashing/naive-psi.h"
#include "../util/parse_options.h"
#include "../util/helpers.h"
uint32_t test_psi_prot(role_type role, CSocket* sock, uint32_t nelements,
uint32_t elebytelen, crypto* crypt);
uint32_t plaintext_intersect(uint32_t myneles, uint32_t pneles, uint32_t bytelen, uint8_t* myelements,
uint8_t* pelements, uint8_t** result);
uint32_t set_up_parameters(role_type role, uint32_t myneles, uint32_t* mybytelen,
uint8_t** elements, uint8_t** pelements, CSocket& sock, crypto* crypt);
int32_t read_psi_test_options(int32_t* argcp, char*** argvp, role_type* role, uint32_t* nruns);
void plot_set(uint8_t* set, uint32_t neles, uint32_t elebytelen);
#endif /* TEST_PSI_H_ */

View File

@ -52,15 +52,6 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t* elebyt
intersect_size = otpsi_client(eleptr, neles, nbins, pneles, internal_bitlen, maskbitlen, crypt_env,
sock, ntasks, &prf_state, &res_pos);
//std::sort(res_pos, res_pos+intersect_size);
//*result = (uint8_t**) malloc(intersect_size * sizeof(uint8_t*));
//*res_bytelen = (uint32_t*) malloc(intersect_size * sizeof(uint32_t));
/*for(i = 0; i < intersect_size; i++) {
(*res_bytelen)[i] = elebytelens[res_pos[i]];
(*result)[i] = (uint8_t*) malloc((*res_bytelen)[i]);
memcpy((*result)[i], elements[res_pos[i]], (*res_bytelen)[i]);
}*/
create_result_from_matches_var_bitlen(result, res_bytelen, elebytelens, elements, res_pos, intersect_size);
}
@ -71,12 +62,12 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t* elebyt
uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebitlen, uint8_t* elements,
uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebytelen, uint8_t* elements,
uint8_t** result, crypto* crypt_env, CSocket* sock, uint32_t ntasks, double epsilon,
bool detailed_timings) {
prf_state_ctx prf_state;
uint32_t maskbytelen, nbins, intersect_size, internal_bitlen, maskbitlen, *res_pos, i, elebytelen;
uint32_t maskbytelen, nbins, intersect_size, internal_bitlen, maskbitlen, *res_pos, i, elebitlen;
uint8_t *eleptr;
timeval t_start, t_end;
@ -84,16 +75,16 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebitl
maskbitlen = pad_to_multiple(crypt_env->get_seclvl().statbits + ceil_log2(neles) + ceil_log2(pneles), 8);
maskbytelen = ceil_divide(maskbitlen, 8);
elebytelen = ceil_divide(elebitlen, 8);
elebitlen = elebytelen * 8;
if(elebitlen > maskbitlen) {
//Hash elements into a smaller domain
eleptr = (uint8_t*) malloc(maskbytelen * neles);
domain_hashing(neles, elements, ceil_divide(elebitlen, 8), eleptr, maskbytelen, crypt_env);
domain_hashing(neles, elements, elebytelen, eleptr, maskbytelen, crypt_env);
internal_bitlen = maskbitlen;
#ifndef BATCH
cout << "Hashing " << neles << " elements with " << elebitlen << " bit-length into " <<
maskbitlen << " representation " << endl;
maskbitlen << " bit representation " << endl;
#endif
} else {
eleptr = elements;
@ -110,10 +101,11 @@ uint32_t otpsi(role_type role, uint32_t neles, uint32_t pneles, uint32_t elebitl
nbins = ceil(epsilon * neles);
intersect_size = otpsi_client(eleptr, neles, nbins, pneles, internal_bitlen, maskbitlen, crypt_env,
sock, ntasks, &prf_state, &res_pos);
*result = (uint8_t*) malloc(intersect_size * elebytelen);
for(i = 0; i < intersect_size; i++) {
memcpy((*result) + i * elebytelen, elements + res_pos[i] * elebytelen, elebytelen);
}
//*result = (uint8_t*) malloc(intersect_size * elebytelen);
//for(i = 0; i < intersect_size; i++) {
// memcpy((*result) + i * elebytelen, elements + res_pos[i] * elebytelen, elebytelen);
//}
create_result_from_matches_fixed_bitlen(result, elebytelen, elements, res_pos, intersect_size);
}
if(elebitlen > maskbitlen)
@ -176,27 +168,23 @@ uint32_t otpsi_client(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_
masks = (uint8_t*) malloc(neles * maskbytelen);
//Perform the OPRG execution
//cout << "otpsi client running ots" << endl;
oprg_client(hash_table, nbins, neles, nelesinbin, outbitlen, maskbitlen, crypt_env, sock, ntasks, masks);
if(DETAILED_TIMINGS) {
gettimeofday(&t_start, NULL);
}
/*uint64_t tmpmask = 0;
for(uint32_t i = 0; i < neles-1; i++) {
memcpy((uint8_t*) &tmpmask, masks + i*maskbytelen, maskbytelen);
cout << "Mask " << i << " : " << (hex) << tmpmask << (dec) << endl; //"intersection found at position " << tmpval[0] << " for key " << tmpbuf[0] << endl;
}*/
#ifdef TIMING
gettimeofday(&t_end, NULL);
cout << "Client: time for OPRG evaluation: " << getMillies(t_start, t_end) << " ms" << endl;
gettimeofday(&t_start, NULL);
#endif
/*#ifdef PRINT_BIN_CONTENT
#ifdef PRINT_BIN_CONTENT
cout << "Client masks: " << endl;
print_bin_content(masks, neles, maskbytelen, NULL, false);
#endif*/
#endif
//receive server masks
server_masks = (uint8_t*) malloc(NUM_HASH_FUNCTIONS * pneles * maskbytelen);
@ -265,6 +253,7 @@ uint32_t otpsi_client(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_
cout << "Time for intersecting:\t\t" << fixed << std::setprecision(2) <<
getMillies(t_start, t_end) << " ms" << endl;
}
free(masks);
free(hash_table);
free(nelesinbin);
@ -720,7 +709,7 @@ uint32_t otpsi_find_intersection(uint32_t** result, uint8_t* my_hashes,
} else {
keys_stored = 1;
tmp_hashbytelen = hashbytelen;
tmpkeys = (uint32_t*) malloc(my_neles * keys_stored * sizeof(uint32_t));
tmpkeys = (uint32_t*) malloc(my_neles * sizeof(uint32_t));
memcpy(tmpkeys, perm, my_neles * sizeof(uint32_t));
}
@ -760,7 +749,6 @@ uint32_t otpsi_find_intersection(uint32_t** result, uint8_t* my_hashes,
}
}
//TODO: workaround since the masks that are inserted into the hash table are too small and collisions occur
if(intersect_ctr > my_neles) {
cout << "more intersections than elements: " << intersect_ctr << " vs " << my_neles << endl;
intersect_ctr = my_neles;
@ -775,8 +763,8 @@ uint32_t otpsi_find_intersection(uint32_t** result, uint8_t* my_hashes,
//cout << "I found " << size_intersect << " intersecting elements" << endl;
free(matches);
free(map);
free(invperm);
free(tmpkeys);
return size_intersect;
}

View File

@ -91,6 +91,8 @@ static void create_result_from_matches_var_bitlen(uint8_t*** result, uint32_t**
*result = (uint8_t**) malloc(sizeof(uint8_t*) * intersect_size);
*resbytelens = (uint32_t*) malloc(sizeof(uint32_t) * intersect_size);
std::sort(matches, matches+intersect_size);
for(i = 0; i < intersect_size; i++) {
(*resbytelens)[i] = inbytelens[matches[i]];
(*result)[i] = (uint8_t*) malloc((*resbytelens)[i]);
@ -101,10 +103,12 @@ static void create_result_from_matches_var_bitlen(uint8_t*** result, uint32_t**
static void create_result_from_matches_fixed_bitlen(uint8_t** result, uint32_t inbytelen, uint8_t* inputs, uint32_t* matches,
uint32_t intersect_size) {
uint32_t i;
*result = (uint8_t*) malloc(sizeof(uint8_t) * intersect_size);
*result = (uint8_t*) malloc(inbytelen * intersect_size);
std::sort(matches, matches+intersect_size);
for(i = 0; i < intersect_size; i++) {
memcpy(result + i * inbytelen, inputs + matches[i] * inbytelen, inbytelen);
memcpy(*(result) + i * inbytelen, inputs + matches[i] * inbytelen, inbytelen);
}
}
@ -268,8 +272,8 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
uint32_t hashbytelen, uint32_t* perm, uint32_t* matches) {
uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint64_t* tmpval;
uint64_t *tmpval, tmpkey = 0;
uint32_t mapbytelen = min((uint32_t) hashbytelen, (uint32_t) sizeof(uint64_t));
uint32_t size_intersect, i, intersect_ctr;
for(i = 0; i < neles; i++) {
@ -278,13 +282,13 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL);
for(i = 0; i < neles; i++) {
g_hash_table_insert(map,(void*) ((uint64_t*) &(hashes[i*hashbytelen])), &(invperm[i]));
memcpy(&tmpkey, hashes + i*hashbytelen, mapbytelen);
g_hash_table_insert(map,(void*) &tmpkey, &(invperm[i]));
}
for(i = 0, intersect_ctr = 0; i < pneles; i++) {
if(g_hash_table_lookup_extended(map, (void*) ((uint64_t*) &(phashes[i*hashbytelen])),
NULL, (void**) &tmpval)) {
memcpy(&tmpkey, phashes+ i*hashbytelen, mapbytelen);
if(g_hash_table_lookup_extended(map, (void*) &tmpkey, NULL, (void**) &tmpval)) {
matches[intersect_ctr] = tmpval[0];
intersect_ctr++;
assert(intersect_ctr <= min(neles, pneles));

View File

@ -9,7 +9,7 @@
#define TYPEDEFS_H_
//#define DEBUG
//#define BATCH
#define BATCH
//#define TIMING
#include <iostream>