Integrated server-aided PSI protocol

This commit is contained in:
Michael Zohner 2015-05-22 14:46:53 +02:00
parent 89f56664ed
commit 922915697d
8 changed files with 203 additions and 158 deletions

View File

@ -16,7 +16,7 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
double epsilon=1.2; double epsilon=1.2;
uint64_t bytes_sent=0, bytes_received=0, mbfac; uint64_t bytes_sent=0, bytes_received=0, mbfac;
uint32_t nelements=0, elebytelen=16, symsecbits=128, intersect_size = 0, i, j, ntasks=1, uint32_t nelements=0, elebytelen=16, symsecbits=128, intersect_size = 0, i, j, ntasks=1,
pnelements, *elebytelens, *res_bytelens; pnelements, *elebytelens, *res_bytelens, nclients = 2;
uint16_t port=7766; uint16_t port=7766;
uint8_t **elements, **intersection; uint8_t **elements, **intersection;
bool detailed_timings=false; bool detailed_timings=false;
@ -32,7 +32,12 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
read_psi_demo_options(&argc, &argv, &role, &protocol, &filename, &address, &nelements, &detailed_timings); read_psi_demo_options(&argc, &argv, &role, &protocol, &filename, &address, &nelements, &detailed_timings);
if(role == SERVER) { if(role == SERVER) {
listen(address.c_str(), port, sockfd.data(), ntasks); if(protocol == TTP) {
sockfd.resize(nclients);
listen(address.c_str(), port, sockfd.data(), nclients);
}
else
listen(address.c_str(), port, sockfd.data(), ntasks);
} else { } else {
for(i = 0; i < ntasks; i++) for(i = 0; i < ntasks; i++)
connect(address.c_str(), port, sockfd[i]); connect(address.c_str(), port, sockfd[i]);
@ -41,13 +46,13 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
gettimeofday(&t_start, NULL); gettimeofday(&t_start, NULL);
//read in files and get elements and byte-length from there //read in files and get elements and byte-length from there
read_elements(&elements, &elebytelens, &nelements, filename); read_elements(&elements, &elebytelens, &nelements, filename);
if(detailed_timings) { if(detailed_timings) {
gettimeofday(&t_end, NULL); gettimeofday(&t_end, NULL);
} }
pnelements = exchange_information(nelements, elebytelen, symsecbits, ntasks, protocol, sockfd[0]); if(protocol != TTP)
pnelements = exchange_information(nelements, elebytelen, symsecbits, ntasks, protocol, sockfd[0]);
//cout << "Performing private set-intersection between " << nelements << " and " << pnelements << " element sets" << endl; //cout << "Performing private set-intersection between " << nelements << " and " << pnelements << " element sets" << endl;
if(detailed_timings) { if(detailed_timings) {
@ -62,7 +67,9 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
&crypto, sockfd.data(), ntasks); &crypto, sockfd.data(), ntasks);
break; break;
case TTP: case TTP:
///ttppsi(role, nelements, elebytelen, elements, &intersection, &crypto, sockfd.data(), nclients, cardinality); break; intersect_size = ttppsi(role, nelements, elebytelens, elements, &intersection, &res_bytelens,
&crypto, sockfd.data(), ntasks);
break;
case DH_ECC: case DH_ECC:
intersect_size = dhpsi(role, nelements, pnelements, elebytelens, elements, &intersection, &res_bytelens, &crypto, intersect_size = dhpsi(role, nelements, pnelements, elebytelens, elements, &intersection, &res_bytelens, &crypto,
sockfd.data(), ntasks); sockfd.data(), ntasks);

View File

@ -73,8 +73,7 @@ uint32_t naivepsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx
ectx.eles.nelements = neles; ectx.eles.nelements = neles;
ectx.eles.output = hashes; ectx.eles.output = hashes;
ectx.eles.perm = perm; ectx.eles.perm = perm;
ectx.sctx.symcrypt = crypt_env;
ectx.hctx.symcrypt = crypt_env;
run_task(ntasks, ectx, hash); run_task(ntasks, ectx, hash);

View File

@ -71,7 +71,7 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.nelements = neles; ectx.eles.nelements = neles;
ectx.eles.outbytelen = hash_bytes; ectx.eles.outbytelen = hash_bytes;
ectx.eles.perm = perm; ectx.eles.perm = perm;
ectx.hctx.symcrypt = crypt_env; ectx.sctx.symcrypt = crypt_env;
#ifdef DEBUG #ifdef DEBUG
@ -86,14 +86,14 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.outbytelen = fe_bytes; ectx.eles.outbytelen = fe_bytes;
ectx.eles.output = encrypted_eles; ectx.eles.output = encrypted_eles;
ectx.eles.hasvarbytelen = false; ectx.eles.hasvarbytelen = false;
ectx.ectx.field = field; ectx.actx.field = field;
ectx.ectx.exponent = exponent; ectx.actx.exponent = exponent;
ectx.ectx.sample = true; ectx.actx.sample = true;
#ifdef DEBUG #ifdef DEBUG
cout << "Hash and encrypting my elements" << endl; cout << "Hash and encrypting my elements" << endl;
#endif #endif
run_task(ntasks, ectx, encrypt); run_task(ntasks, ectx, asym_encrypt);
peles = (uint8_t*) malloc(sizeof(uint8_t) * pneles * fe_bytes); peles = (uint8_t*) malloc(sizeof(uint8_t) * pneles * fe_bytes);
@ -110,13 +110,13 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.fixedbytelen = fe_bytes; ectx.eles.fixedbytelen = fe_bytes;
ectx.eles.outbytelen = fe_bytes; ectx.eles.outbytelen = fe_bytes;
ectx.eles.hasvarbytelen = false; ectx.eles.hasvarbytelen = false;
ectx.ectx.exponent = exponent; ectx.actx.exponent = exponent;
ectx.ectx.sample = false; ectx.actx.sample = false;
#ifdef DEBUG #ifdef DEBUG
cout << "Encrypting partners elements" << endl; cout << "Encrypting partners elements" << endl;
#endif #endif
run_task(ntasks, ectx, encrypt); run_task(ntasks, ectx, asym_encrypt);
/* if only the cardinality should be computed, permute the elements randomly again. Otherwise don't permute */ /* if only the cardinality should be computed, permute the elements randomly again. Otherwise don't permute */
if(cardinality) { if(cardinality) {
@ -136,7 +136,7 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.outbytelen = hash_bytes; ectx.eles.outbytelen = hash_bytes;
ectx.eles.hasvarbytelen = false; ectx.eles.hasvarbytelen = false;
ectx.eles.perm = cardinality_perm; ectx.eles.perm = cardinality_perm;
ectx.hctx.symcrypt = crypt_env; ectx.sctx.symcrypt = crypt_env;
#ifdef DEBUG #ifdef DEBUG
cout << "Hashing elements" << endl; cout << "Hashing elements" << endl;

View File

@ -6,8 +6,8 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
CSocket* sockfds = socket;//(CSocket*) malloc(sizeof(CSocket) * nclients); CSocket* sockfds = socket;//(CSocket*) malloc(sizeof(CSocket) * nclients);
uint32_t* neles = (uint32_t*) malloc(sizeof(uint32_t) * nclients); uint32_t* neles = (uint32_t*) malloc(sizeof(uint32_t) * nclients);
uint8_t** csets = (uint8_t**) malloc(sizeof(uint8_t*) * nclients); uint8_t** csets = (uint8_t**) malloc(sizeof(uint8_t*) * nclients);
uint8_t* intersect; uint32_t temp, maskbytelen, intersectsize, minset, i, j;
uint32_t temp, maskbytelen, intersectsize, minset, i; CBitVector* intersection = new CBitVector[nclients];
#ifndef BATCH #ifndef BATCH
cout << "Connections with all " << nclients << " clients established" << endl; cout << "Connections with all " << nclients << " clients established" << endl;
@ -19,16 +19,17 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
sockfds[i].Receive(&temp, sizeof(uint32_t)); sockfds[i].Receive(&temp, sizeof(uint32_t));
if(i == 0) { maskbytelen = temp; minset = neles[i];} if(i == 0) { maskbytelen = temp; minset = neles[i];}
if(neles[i] < minset) minset = neles[i]; if(neles[i] < minset) minset = neles[i];
assert(maskbytelen == temp);
#ifndef BATCH #ifndef BATCH
cout << "Client " << i << " holds " << neles[i] << " elements of length " << (temp * 8) << "-bit" << endl; cout << "Client " << i << " holds " << neles[i] << " elements of length " << (temp * 8) << "-bit" << endl;
#endif #endif
intersection[i].ResizeinBytes(ceil_divide(neles[i], 8));
intersection[i].Reset();
assert(maskbytelen == temp);
} }
#ifndef BATCH #ifndef BATCH
cout <<"Receiving the client's elements" << endl; cout <<"Receiving the client's elements" << endl;
#endif #endif
/* Allocate sufficient size for the intersecting elements */
intersect = (uint8_t*) malloc(sizeof(uint8_t*) * minset * maskbytelen);
/* Receive the permuted and masked sets of all clients */ /* Receive the permuted and masked sets of all clients */
for(i = 0; i < nclients; i++) { for(i = 0; i < nclients; i++) {
@ -40,7 +41,10 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
cout << "Computing intersection for the clients" << endl; cout << "Computing intersection for the clients" << endl;
#endif #endif
/* Compute Intersection */ /* Compute Intersection */
intersectsize = compute_intersection(nclients, neles, csets, intersect, maskbytelen); intersectsize = compute_intersection(nclients, neles, csets, intersection, maskbytelen);
/* Enter at which position an intersection was found */
#ifndef BATCH #ifndef BATCH
cout << "sending all " << intersectsize << " intersecting elements to the clients" << endl; cout << "sending all " << intersectsize << " intersecting elements to the clients" << endl;
#endif #endif
@ -48,7 +52,7 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
for(i = 0; i < nclients; i++) { for(i = 0; i < nclients; i++) {
sockfds[i].Send(&intersectsize, sizeof(uint32_t)); sockfds[i].Send(&intersectsize, sizeof(uint32_t));
if(!cardinality) if(!cardinality)
sockfds[i].Send(intersect, intersectsize * maskbytelen); sockfds[i].Send(intersection[i].GetArr(), ceil_divide(neles[i], 8));
} }
/* Cleanup */ /* Cleanup */
@ -60,7 +64,7 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
* for the n-party case a BF-based approach makes more sense. * for the n-party case a BF-based approach makes more sense.
*/ */
//TODO currently only works for 128 bit masks //TODO currently only works for 128 bit masks
uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, uint8_t* intersect, uint32_t entrybytelen) { uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, CBitVector* intersection, uint32_t entrybytelen) {
// Create the GHashTable // Create the GHashTable
GHashTable *map = NULL, *tmpmap = NULL; GHashTable *map = NULL, *tmpmap = NULL;
GHashTableIter iter; GHashTableIter iter;
@ -72,9 +76,10 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
NULL // cleanup value NULL // cleanup value
); );
uint32_t i, j, intersectsize, ctr = 0; uint32_t i, j, intersectsize, ctr = 0, k;
uint64_t* tmpval = (uint64_t*) malloc(sizeof(uint64_t)); uint64_t* tmpval;
uint64_t* tmpkey = (uint64_t*) malloc(sizeof(uint64_t)); uint64_t* tmpkey = (uint64_t*) malloc(sizeof(uint64_t));
uint64_t* query;
#ifndef BATCH #ifndef BATCH
cout << "Inserting the items into the hash table " << endl; cout << "Inserting the items into the hash table " << endl;
#endif #endif
@ -83,7 +88,10 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
#ifdef DEBUG #ifdef DEBUG
cout << "Inserted item: " << (hex) << ((uint64_t*) csets[0])[2*i] << " "<< ((uint64_t*) csets[0])[2*i+1] << (dec) << endl; cout << "Inserted item: " << (hex) << ((uint64_t*) csets[0])[2*i] << " "<< ((uint64_t*) csets[0])[2*i+1] << (dec) << endl;
#endif #endif
g_hash_table_insert(map,(void*) &((uint64_t*)csets[0])[2*i], &(((uint64_t*)csets[0])[2*i+1])); tmpval = (uint64_t*) malloc(2*sizeof(uint64_t));
tmpval[0] = (((uint64_t*)csets[0])[2*i+1]);
tmpval[1] = i;
g_hash_table_insert(map,(void*) &((uint64_t*)csets[0])[2*i], tmpval);//&(((uint64_t*)csets[0])[2*i+1]));
} }
#ifdef DEBUG #ifdef DEBUG
g_hash_table_foreach( map, printKeyValue, NULL ); g_hash_table_foreach( map, printKeyValue, NULL );
@ -108,11 +116,17 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
cout << "Checking for Key: " << (hex) << ((uint64_t*) csets[i])[2*j] << " "<< ((uint64_t*) csets[i])[2*j+1] << (dec) << endl; cout << "Checking for Key: " << (hex) << ((uint64_t*) csets[i])[2*j] << " "<< ((uint64_t*) csets[i])[2*j+1] << (dec) << endl;
#endif #endif
if(g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)csets[i])[2*j]), if(g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)csets[i])[2*j]),
NULL, (void**) &tmpval) && (*tmpval == ((uint64_t*)csets[i])[2*j+1])) { NULL, (void**) &query) && (*query == ((uint64_t*)csets[i])[2*j+1])) {
#ifdef DEBUG #ifdef DEBUG
cout << "Key was found" << endl; cout << "Key was found" << endl;
#endif #endif
g_hash_table_insert(tmpmap,(void*) &(((uint64_t*)csets[i])[2*j]),&(((uint64_t*)csets[i])[2*j+1])); tmpval = (uint64_t*) malloc((i+2)*sizeof(uint64_t));
tmpval[0] = (((uint64_t*)csets[i])[2*j+1]);
for(k = 1; k < i+1; k++) {
tmpval[k] = query[k];
}
tmpval[i+1] = j;
g_hash_table_insert(tmpmap,(void*) &(((uint64_t*)csets[i])[2*j]), tmpval);//&(((uint64_t*)csets[i])[2*j+1]));
} else { } else {
#ifdef DEBUG #ifdef DEBUG
@ -143,8 +157,9 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
#ifdef DEBUG #ifdef DEBUG
cout << (hex) << tmpkey[0] << " " << tmpval[0] << (dec)<< endl; cout << (hex) << tmpkey[0] << " " << tmpval[0] << (dec)<< endl;
#endif #endif
((uint64_t*) intersect)[ctr++] = tmpkey[0]; for(i = 0; i < nclients; i++) {
((uint64_t*) intersect)[ctr++] = tmpval[0]; intersection[i].SetBit(tmpval[i+1], 1);
}
} }
gettimeofday(&end, NULL); gettimeofday(&end, NULL);
#ifdef TIMING #ifdef TIMING
@ -162,69 +177,54 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
} }
uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements, uint32_t client_routine(uint32_t neles, task_ctx ectx, uint32_t* matches,
uint8_t** result, crypto* crypt, CSocket* socket, bool cardinality) { crypto* crypt_env, CSocket* socket, uint32_t ntasks, bool cardinality) {
uint32_t maskbytelen = 16, intersectsize, i, j; uint32_t maskbytelen, intersectsize, i, matchctr;
uint8_t* masks = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen);
uint8_t* intersect = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen);
uint32_t* perm;
uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint32_t* tmpval = (uint32_t*) malloc(sizeof(uint32_t));
GHashTable *map;
//crypto crypto(symsecbits, (uint8_t*) const_seed);
// cout << "Starting client with " << neles << " elements of " << (8*elebytelen) << "-bit length with server "
// << address << ":" << port << endl;
CSocket* sockfd = socket;
//connect(address, port, sockfd);
sockfd->Send((uint8_t*) &neles, sizeof(uint32_t));
sockfd->Send((uint8_t*) &maskbytelen, sizeof(uint32_t));
perm = mask_and_permute_elements(neles, elebytelen, elements, maskbytelen, masks, crypt->get_seclvl().symbits, crypt); uint8_t* masks;
uint32_t *perm, *invperm;
CBitVector inIntersection(neles);
sockfd->Send(masks, maskbytelen * neles); //TODO works only fine for equally sized sets, if one set is bigger than the other, this will fail!
maskbytelen = 16;//ceil_divide(crypt_env->get_seclvl().statbits + 2*ceil_log2(neles), 8);
if(!cardinality) { masks = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen);
for(i = 0; i < neles; i++) { perm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
invperm[perm[i]] = i; invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
}
map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL); /* Generate the random permutation the elements */
for(i = 0; i < neles; i++) { crypt_env->gen_rnd_perm(perm, neles);
g_hash_table_insert(map,(void*) &((uint64_t*)masks)[2*i], &(invperm[i]));
} socket->Send((uint8_t*) &neles, sizeof(uint32_t));
socket->Send((uint8_t*) &maskbytelen, sizeof(uint32_t));
ectx.eles.outbytelen = maskbytelen,
ectx.eles.nelements = neles;
ectx.eles.output = masks;
ectx.eles.perm = perm;
ectx.sctx.symcrypt = crypt_env;
ectx.sctx.keydata = (uint8_t*) const_seed;
run_task(ntasks, ectx, hash);
socket->Send(masks, maskbytelen * neles);
socket->Receive(&intersectsize, sizeof(uint32_t));
for(i = 0; i < neles; i++) {
invperm[perm[i]] = i;
} }
sockfd->Receive(&intersectsize, sizeof(uint32_t));
if(!cardinality) { if(!cardinality) {
sockfd->Receive(intersect, maskbytelen * intersectsize); socket->Receive(inIntersection.GetArr(), ceil_divide(neles, 8));
#ifdef DEBUG for(i = 0, matchctr = 0; i < neles; i++) {
cout << "The intersection contains " << intersectsize << " elements: " << endl; if(inIntersection.GetBit(i)) {
for(i = 0; i < intersectsize; i++) { matches[matchctr] = invperm[i];
cout << (hex) << ((uint64_t*)intersect)[2*i] << " " << ((uint64_t*)intersect)[2*i+1] << (dec) << endl; matchctr++;
}
} }
#endif
*result = (uint8_t*) malloc(elebytelen * intersectsize);
//uint8_t* tmpbuf = (uint8_t*) malloc(maskbytelen);
for(i = 0; i < intersectsize; i++) {
g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)intersect)[2*i]), NULL, (void**) &tmpval);
memcpy((*result) + i * elebytelen, elements + tmpval[0] * elebytelen, elebytelen);
//crypto.decrypt(tmpbuf, intersect+i*maskbytelen, maskbytelen);
//memcpy((*result) + i * elebytelen, tmpbuf, elebytelen);
#ifdef DEBUG
cout << ((uint32_t*) elements)[tmpval[0]] << ", ";
#endif
}
#ifdef DEBUG
cout << endl;
#endif
} }
free(perm); free(perm);
@ -233,51 +233,52 @@ uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements,
return intersectsize; return intersectsize;
} }
void printKeyValue( gpointer key, gpointer value, gpointer userData ) {
uint64_t realKey = *((uint64_t*)key);
uint64_t realValue = *((uint64_t*)value);
cout << (hex) << realKey << ": " << realValue << (dec) << endl;
return;
}
uint32_t* mask_and_permute_elements(uint32_t neles, uint32_t elebytelen, uint8_t*
elements, uint32_t maskbytelen, uint8_t* masks, uint32_t symsecbits, crypto* crypto) {
uint32_t* perm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint8_t* maskpermptr;
uint32_t i;
//Get random permutation
crypto->gen_rnd_perm(perm, neles);
//crypto->seed_aes_enc(client_psk);
//Hash and permute all elements
for(i = 0; i < neles; i++) {
//cout << "Performing encryption for " << i << "-th element " << ((uint32_t*) elements)[i] << ": ";
maskpermptr = masks + perm[i] * maskbytelen;
crypto->hash(maskpermptr, maskbytelen, elements + i * elebytelen, elebytelen);
//crypto->encrypt(maskpermptr, elements+i*elebytelen, elebytelen);
//cout <<(hex)<< ((uint64_t*) maskpermptr)[0] << ((uint64_t*) maskpermptr)[1] << (dec) << endl;
#ifdef DEBUG
cout << "Resulting hash for element " << ((uint32_t*)elements)[i] << ": " << (hex) << ((uint64_t*) maskpermptr)[0] <<
" " << ((uint64_t*) maskpermptr)[1] << (dec) << endl;
#endif
}
//free(perm);
return perm;
}
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements, uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements,
uint8_t** intersection, crypto* crypt, CSocket* sockets, uint32_t nclients, bool cardinality) { uint8_t** result, crypto* crypt, CSocket* sockets, uint32_t ntasks, uint32_t nclients, bool cardinality) {
if(role == 0) { //Start the server if(role == 0) { //Start the server
//TODO maybe rerun infinitely //TODO maybe rerun infinitely
server_routine(nclients, sockets, cardinality); server_routine(nclients, sockets, cardinality);
return 0; return 0;
} else { //Start clients } else { //Start clients
return client_routine(neles, elebytelen, elements, intersection, crypt, sockets, cardinality); task_ctx ectx;
ectx.eles.input1d = elements;
ectx.eles.fixedbytelen = elebytelen;
ectx.eles.hasvarbytelen = false;
uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint32_t intersect_size = client_routine(neles, ectx, matches, crypt, sockets, ntasks, cardinality);
create_result_from_matches_fixed_bitlen(result, elebytelen, elements, matches, intersect_size);
free(matches);
return intersect_size;
}
}
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t* elebytelens, uint8_t** elements,
uint8_t*** result, uint32_t** resbytelens, crypto* crypt, CSocket* sockets,
uint32_t ntasks, uint32_t nclients, bool cardinality) {
if(role == 0) { //Start the server
//TODO maybe rerun infinitely
server_routine(nclients, sockets, cardinality);
return 0;
} else { //Start clients
task_ctx ectx;
ectx.eles.input2d = elements;
ectx.eles.varbytelens = elebytelens;
ectx.eles.hasvarbytelen = true;
uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint32_t intersect_size = client_routine(neles, ectx, matches, crypt, sockets, ntasks, cardinality);
create_result_from_matches_var_bitlen(result, resbytelens, elebytelens, elements, matches, intersect_size);
free(matches);
return intersect_size;
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* shpsi.h * sapsi.h
* *
* Created on: Jul 1, 2014 * Created on: Jul 1, 2014
* Author: mzohner * Author: mzohner
@ -13,12 +13,18 @@
#include "../util/socket.h" #include "../util/socket.h"
#include "../util/typedefs.h" #include "../util/typedefs.h"
#include "../util/connection.h" #include "../util/connection.h"
#include "../util/helpers.h"
#include "../util/cbitvector.h"
/* start both roles*/ /* start both roles*/
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements, uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements,
uint8_t** intersection, crypto* crypt, CSocket* socket, uint32_t nclients = 0, bool cardinality=false); uint8_t** intersection, crypto* crypt, CSocket* socket, uint32_t ntasks, uint32_t nclients = 2, bool cardinality=false);
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t* elebytelens, uint8_t** elements,
uint8_t*** result, uint32_t** resbytelens, crypto* crypt, CSocket* sockets,
uint32_t ntasks, uint32_t nclients = 2, bool cardinality = false);
/* /*
* Params: * Params:
@ -30,8 +36,8 @@ uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* el
* port: port that the server is listening on * port: port that the server is listening on
* return: number of intersecting elements * return: number of intersecting elements
*/ */
uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements, uint32_t client_routine(uint32_t neles, task_ctx ectx, uint32_t* matches, crypto* crypt,
uint8_t** intersection, crypto* crypt, CSocket* socket, bool cardinality); CSocket* socket, uint32_t ntasks, bool cardinality);
/* /*
* Mask and permute the elements using the pre-shared key * Mask and permute the elements using the pre-shared key
@ -48,7 +54,7 @@ uint32_t* mask_and_permute_elements(uint32_t neles, uint32_t elebytelen, uint8_t
*/ */
void server_routine(uint32_t nclients, CSocket* socket, bool cardinality); void server_routine(uint32_t nclients, CSocket* socket, bool cardinality);
uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, uint8_t* intersect, uint32_t entrybytelen); uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, CBitVector* intersection, uint32_t entrybytelen);
void printKeyValue( gpointer key, gpointer value, gpointer userData ); void printKeyValue( gpointer key, gpointer value, gpointer userData );

View File

@ -120,13 +120,13 @@ void crypto::gen_rnd_uniform(uint8_t* resbuf, uint64_t mod) {
void crypto::encrypt(AES_KEY_CTX* enc_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) { void crypto::encrypt(AES_KEY_CTX* enc_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) {
int32_t dummy; int32_t dummy;
EVP_EncryptUpdate(enc_key, resbuf, &dummy, inbuf, ninbytes); EVP_EncryptUpdate(enc_key, resbuf, &dummy, inbuf, ninbytes);
//EVP_EncryptFinal_ex(enc_key, resbuf, &dummy); EVP_EncryptFinal_ex(enc_key, resbuf, &dummy);
} }
void crypto::decrypt(AES_KEY_CTX* dec_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) { void crypto::decrypt(AES_KEY_CTX* dec_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) {
int32_t dummy; int32_t dummy;
//cout << "inbuf = " << (hex) << ((uint64_t*) inbuf)[0] << ((uint64_t*) inbuf)[1] << (dec) << endl; //cout << "inbuf = " << (hex) << ((uint64_t*) inbuf)[0] << ((uint64_t*) inbuf)[1] << (dec) << endl;
EVP_DecryptUpdate(dec_key, resbuf, &dummy, inbuf, ninbytes); EVP_DecryptUpdate(dec_key, resbuf, &dummy, inbuf, ninbytes);
//EVP_DecryptFinal_ex(dec_key, resbuf, &dummy); EVP_DecryptFinal_ex(dec_key, resbuf, &dummy);
//cout << "outbuf = " << (hex) << ((uint64_t*) resbuf)[0] << ((uint64_t*) resbuf)[1] << (dec) << " (" << dummy << ")" << endl; //cout << "outbuf = " << (hex) << ((uint64_t*) resbuf)[0] << ((uint64_t*) resbuf)[1] << (dec) << " (" << dummy << ")" << endl;
} }
@ -212,6 +212,10 @@ void crypto::hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t
hash_routine(resbuf, noutbytes, inbuf, ninbytes, sha_hash_buf); hash_routine(resbuf, noutbytes, inbuf, ninbytes, sha_hash_buf);
} }
void crypto::hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint8_t* tmpbuf) {
hash_routine(resbuf, noutbytes, inbuf, ninbytes, tmpbuf);
}
//A fixed-key hashing scheme that uses AES, should not be used for real hashing, hashes to AES_BYTES bytes //A fixed-key hashing scheme that uses AES, should not be used for real hashing, hashes to AES_BYTES bytes
void crypto::fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes) { void crypto::fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes) {

View File

@ -62,6 +62,7 @@ public:
//Hash routines //Hash routines
void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes); void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes);
void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint8_t* tmpbuf);
void hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint32_t ctr); void hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint32_t ctr);
void fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes); void fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes);
void fixed_key_aes_hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes); void fixed_key_aes_hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes);

View File

@ -31,12 +31,12 @@ struct element_ctx {
bool hasvarbytelen; bool hasvarbytelen;
}; };
struct hash_ctx { struct sym_ctx {
crypto* symcrypt; crypto* symcrypt;
uint8_t* keydata;
}; };
struct encrypt_ctx { struct asym_ctx {
num* exponent; num* exponent;
pk_crypto* field; pk_crypto* field;
bool sample; bool sample;
@ -45,8 +45,8 @@ struct encrypt_ctx {
struct task_ctx { struct task_ctx {
element_ctx eles; element_ctx eles;
union { union {
hash_ctx hctx; sym_ctx sctx;
encrypt_ctx ectx; asym_ctx actx;
}; };
}; };
@ -108,20 +108,20 @@ static void create_result_from_matches_fixed_bitlen(uint8_t** result, uint32_t i
} }
} }
static void *encrypt(void* context) { static void *asym_encrypt(void* context) {
#ifdef DEBUG #ifdef DEBUG
cout << "Encryption task started" << endl; cout << "Encryption task started" << endl;
#endif #endif
pk_crypto* field = ((task_ctx*) context)->ectx.field; pk_crypto* field = ((task_ctx*) context)->actx.field;
element_ctx electx = ((task_ctx*) context)->eles; element_ctx electx = ((task_ctx*) context)->eles;
num* e = ((task_ctx*) context)->ectx.exponent; num* e = ((task_ctx*) context)->actx.exponent;
fe* tmpfe = field->get_fe(); fe* tmpfe = field->get_fe();
uint8_t *inptr=electx.input1d, *outptr=electx.output; uint8_t *inptr=electx.input1d, *outptr=electx.output;
uint32_t i; uint32_t i;
for(i = 0; i < electx.nelements; i++, inptr+=electx.fixedbytelen, outptr+=electx.outbytelen) { for(i = 0; i < electx.nelements; i++, inptr+=electx.fixedbytelen, outptr+=electx.outbytelen) {
if(((task_ctx*) context)->ectx.sample) { if(((task_ctx*) context)->actx.sample) {
tmpfe->sample_fe_from_bytes(inptr, electx.fixedbytelen); tmpfe->sample_fe_from_bytes(inptr, electx.fixedbytelen);
//cout << "Mapped " << ((uint32_t*) inptr)[0] << " to "; //cout << "Mapped " << ((uint32_t*) inptr)[0] << " to ";
} else { } else {
@ -135,29 +135,71 @@ static void *encrypt(void* context) {
return 0; return 0;
} }
static void *hash(void* context) { static void *sym_encrypt(void* context) {
#ifdef DEBUG #ifdef DEBUG
cout << "Hashing thread started" << endl; cout << "Hashing thread started" << endl;
#endif #endif
hash_ctx hdata = ((task_ctx*) context)->hctx; sym_ctx hdata = ((task_ctx*) context)->sctx;
element_ctx electx = ((task_ctx*) context)->eles; element_ctx electx = ((task_ctx*) context)->eles;
crypto* crypt_env = hdata.symcrypt; crypto* crypt_env = hdata.symcrypt;
AES_KEY_CTX aes_key;
//cout << "initializing key" << endl;
crypt_env->init_aes_key(&aes_key, hdata.keydata);
//cout << "initialized key" << endl;
uint8_t* aes_buf = (uint8_t*) malloc(AES_BYTES);
uint32_t* perm = electx.perm; uint32_t* perm = electx.perm;
uint32_t i; uint32_t i;
if(electx.hasvarbytelen) { if(electx.hasvarbytelen) {
uint8_t **inptr = electx.input2d; uint8_t **inptr = electx.input2d;
for(i = electx.startelement; i < electx.endelement; i++) { for(i = electx.startelement; i < electx.endelement; i++) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i]); //crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i]);
//cout << "encrypting i = " << i << ", perm = " << perm [i] << ", outbytelen = " << electx.outbytelen << endl;
crypt_env->encrypt(&aes_key, aes_buf, inptr[i], electx.varbytelens[i]);
memcpy(electx.output+perm[i]*electx.outbytelen, aes_buf, electx.outbytelen);
} }
} else { } else {
uint8_t *inptr = electx.input1d; uint8_t *inptr = electx.input1d;
for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) { for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen); //crypt_env->hash(&aes_key, electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen);
crypt_env->encrypt(&aes_key, aes_buf, inptr, electx.fixedbytelen);
memcpy(electx.output+perm[i]*electx.outbytelen, aes_buf, electx.outbytelen);
} }
} }
//cout << "Returning" << endl;
//free(aes_buf);
return 0;
}
static void *hash(void* context) {
#ifdef DEBUG
cout << "Hashing thread started" << endl;
#endif
sym_ctx hdata = ((task_ctx*) context)->sctx;
element_ctx electx = ((task_ctx*) context)->eles;
crypto* crypt_env = hdata.symcrypt;
uint32_t* perm = electx.perm;
uint32_t i;
uint8_t* tmphashbuf = (uint8_t*) malloc(crypt_env->get_hash_bytes());
if(electx.hasvarbytelen) {
uint8_t **inptr = electx.input2d;
for(i = electx.startelement; i < electx.endelement; i++) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i], tmphashbuf);
}
} else {
uint8_t *inptr = electx.input1d;
for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen, tmphashbuf);
}
}
free(tmphashbuf);
return 0; return 0;
} }
@ -198,8 +240,6 @@ static void run_task(uint32_t nthreads, task_ctx context, void* (*func)(void*) )
neles_cur = min(context.eles.nelements - electr, neles_thread); neles_cur = min(context.eles.nelements - electr, neles_thread);
memcpy(contexts + i, &context, sizeof(task_ctx)); memcpy(contexts + i, &context, sizeof(task_ctx));
contexts[i].eles.nelements = neles_cur; contexts[i].eles.nelements = neles_cur;
//contexts[i].eles.input = context.eles.input + (context.eles.inbytelen * electr);
//contexts[i].eles.output = context.eles.output + (context.eles.outbytelen * electr);
contexts[i].eles.startelement = electr; contexts[i].eles.startelement = electr;
contexts[i].eles.endelement = electr + neles_cur; contexts[i].eles.endelement = electr + neles_cur;
electr += neles_cur; electr += neles_cur;
@ -228,7 +268,6 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
uint32_t hashbytelen, uint32_t* perm, uint32_t* matches) { uint32_t hashbytelen, uint32_t* perm, uint32_t* matches) {
uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles); uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
//uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint64_t* tmpval; uint64_t* tmpval;
uint32_t size_intersect, i, intersect_ctr; uint32_t size_intersect, i, intersect_ctr;
@ -236,17 +275,12 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
for(i = 0; i < neles; i++) { for(i = 0; i < neles; i++) {
invperm[perm[i]] = i; invperm[perm[i]] = i;
} }
//cout << "My number of elements. " << neles << ", partner number of elements: " << pneles << ", maskbytelen: " << hashbytelen << endl;
GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL); GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL);
for(i = 0; i < neles; i++) { for(i = 0; i < neles; i++) {
g_hash_table_insert(map,(void*) ((uint64_t*) &(hashes[i*hashbytelen])), &(invperm[i])); g_hash_table_insert(map,(void*) ((uint64_t*) &(hashes[i*hashbytelen])), &(invperm[i]));
} }
//for(i = 0; i < pneles; i++) {
// ((uint64_t*) &(phashes[i*hashbytelen]))[0]++;
//}
for(i = 0, intersect_ctr = 0; i < pneles; i++) { for(i = 0, intersect_ctr = 0; i < pneles; i++) {
if(g_hash_table_lookup_extended(map, (void*) ((uint64_t*) &(phashes[i*hashbytelen])), if(g_hash_table_lookup_extended(map, (void*) ((uint64_t*) &(phashes[i*hashbytelen])),
@ -259,14 +293,7 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
size_intersect = intersect_ctr; size_intersect = intersect_ctr;
//result = (uint8_t**) malloc(sizeof(uint8_t*));
//(*result) = (uint8_t*) malloc(sizeof(uint8_t) * size_intersect * elebytelen);
//for(i = 0; i < size_intersect; i++) {
// memcpy((*result) + i * elebytelen, elements + matches[i] * elebytelen, elebytelen);
//}
free(invperm); free(invperm);
//free(matches);
return size_intersect; return size_intersect;
} }