Integrated server-aided PSI protocol

This commit is contained in:
Michael Zohner 2015-05-22 14:46:53 +02:00
parent 89f56664ed
commit 922915697d
8 changed files with 203 additions and 158 deletions

View File

@ -16,7 +16,7 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
double epsilon=1.2;
uint64_t bytes_sent=0, bytes_received=0, mbfac;
uint32_t nelements=0, elebytelen=16, symsecbits=128, intersect_size = 0, i, j, ntasks=1,
pnelements, *elebytelens, *res_bytelens;
pnelements, *elebytelens, *res_bytelens, nclients = 2;
uint16_t port=7766;
uint8_t **elements, **intersection;
bool detailed_timings=false;
@ -32,7 +32,12 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
read_psi_demo_options(&argc, &argv, &role, &protocol, &filename, &address, &nelements, &detailed_timings);
if(role == SERVER) {
listen(address.c_str(), port, sockfd.data(), ntasks);
if(protocol == TTP) {
sockfd.resize(nclients);
listen(address.c_str(), port, sockfd.data(), nclients);
}
else
listen(address.c_str(), port, sockfd.data(), ntasks);
} else {
for(i = 0; i < ntasks; i++)
connect(address.c_str(), port, sockfd[i]);
@ -41,13 +46,13 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
gettimeofday(&t_start, NULL);
//read in files and get elements and byte-length from there
read_elements(&elements, &elebytelens, &nelements, filename);
if(detailed_timings) {
gettimeofday(&t_end, NULL);
}
pnelements = exchange_information(nelements, elebytelen, symsecbits, ntasks, protocol, sockfd[0]);
if(protocol != TTP)
pnelements = exchange_information(nelements, elebytelen, symsecbits, ntasks, protocol, sockfd[0]);
//cout << "Performing private set-intersection between " << nelements << " and " << pnelements << " element sets" << endl;
if(detailed_timings) {
@ -62,7 +67,9 @@ int32_t psi_demonstrator(int32_t argc, char** argv) {
&crypto, sockfd.data(), ntasks);
break;
case TTP:
///ttppsi(role, nelements, elebytelen, elements, &intersection, &crypto, sockfd.data(), nclients, cardinality); break;
intersect_size = ttppsi(role, nelements, elebytelens, elements, &intersection, &res_bytelens,
&crypto, sockfd.data(), ntasks);
break;
case DH_ECC:
intersect_size = dhpsi(role, nelements, pnelements, elebytelens, elements, &intersection, &res_bytelens, &crypto,
sockfd.data(), ntasks);

View File

@ -73,8 +73,7 @@ uint32_t naivepsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx
ectx.eles.nelements = neles;
ectx.eles.output = hashes;
ectx.eles.perm = perm;
ectx.hctx.symcrypt = crypt_env;
ectx.sctx.symcrypt = crypt_env;
run_task(ntasks, ectx, hash);

View File

@ -71,7 +71,7 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.nelements = neles;
ectx.eles.outbytelen = hash_bytes;
ectx.eles.perm = perm;
ectx.hctx.symcrypt = crypt_env;
ectx.sctx.symcrypt = crypt_env;
#ifdef DEBUG
@ -86,14 +86,14 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.outbytelen = fe_bytes;
ectx.eles.output = encrypted_eles;
ectx.eles.hasvarbytelen = false;
ectx.ectx.field = field;
ectx.ectx.exponent = exponent;
ectx.ectx.sample = true;
ectx.actx.field = field;
ectx.actx.exponent = exponent;
ectx.actx.sample = true;
#ifdef DEBUG
cout << "Hash and encrypting my elements" << endl;
#endif
run_task(ntasks, ectx, encrypt);
run_task(ntasks, ectx, asym_encrypt);
peles = (uint8_t*) malloc(sizeof(uint8_t) * pneles * fe_bytes);
@ -110,13 +110,13 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.fixedbytelen = fe_bytes;
ectx.eles.outbytelen = fe_bytes;
ectx.eles.hasvarbytelen = false;
ectx.ectx.exponent = exponent;
ectx.ectx.sample = false;
ectx.actx.exponent = exponent;
ectx.actx.sample = false;
#ifdef DEBUG
cout << "Encrypting partners elements" << endl;
#endif
run_task(ntasks, ectx, encrypt);
run_task(ntasks, ectx, asym_encrypt);
/* if only the cardinality should be computed, permute the elements randomly again. Otherwise don't permute */
if(cardinality) {
@ -136,7 +136,7 @@ uint32_t dhpsi(role_type role, uint32_t neles, uint32_t pneles, task_ctx ectx, c
ectx.eles.outbytelen = hash_bytes;
ectx.eles.hasvarbytelen = false;
ectx.eles.perm = cardinality_perm;
ectx.hctx.symcrypt = crypt_env;
ectx.sctx.symcrypt = crypt_env;
#ifdef DEBUG
cout << "Hashing elements" << endl;

View File

@ -6,8 +6,8 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
CSocket* sockfds = socket;//(CSocket*) malloc(sizeof(CSocket) * nclients);
uint32_t* neles = (uint32_t*) malloc(sizeof(uint32_t) * nclients);
uint8_t** csets = (uint8_t**) malloc(sizeof(uint8_t*) * nclients);
uint8_t* intersect;
uint32_t temp, maskbytelen, intersectsize, minset, i;
uint32_t temp, maskbytelen, intersectsize, minset, i, j;
CBitVector* intersection = new CBitVector[nclients];
#ifndef BATCH
cout << "Connections with all " << nclients << " clients established" << endl;
@ -19,16 +19,17 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
sockfds[i].Receive(&temp, sizeof(uint32_t));
if(i == 0) { maskbytelen = temp; minset = neles[i];}
if(neles[i] < minset) minset = neles[i];
assert(maskbytelen == temp);
#ifndef BATCH
cout << "Client " << i << " holds " << neles[i] << " elements of length " << (temp * 8) << "-bit" << endl;
#endif
intersection[i].ResizeinBytes(ceil_divide(neles[i], 8));
intersection[i].Reset();
assert(maskbytelen == temp);
}
#ifndef BATCH
cout <<"Receiving the client's elements" << endl;
#endif
/* Allocate sufficient size for the intersecting elements */
intersect = (uint8_t*) malloc(sizeof(uint8_t*) * minset * maskbytelen);
/* Receive the permuted and masked sets of all clients */
for(i = 0; i < nclients; i++) {
@ -40,7 +41,10 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
cout << "Computing intersection for the clients" << endl;
#endif
/* Compute Intersection */
intersectsize = compute_intersection(nclients, neles, csets, intersect, maskbytelen);
intersectsize = compute_intersection(nclients, neles, csets, intersection, maskbytelen);
/* Enter at which position an intersection was found */
#ifndef BATCH
cout << "sending all " << intersectsize << " intersecting elements to the clients" << endl;
#endif
@ -48,7 +52,7 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
for(i = 0; i < nclients; i++) {
sockfds[i].Send(&intersectsize, sizeof(uint32_t));
if(!cardinality)
sockfds[i].Send(intersect, intersectsize * maskbytelen);
sockfds[i].Send(intersection[i].GetArr(), ceil_divide(neles[i], 8));
}
/* Cleanup */
@ -60,7 +64,7 @@ void server_routine(uint32_t nclients, CSocket* socket, bool cardinality) {
* for the n-party case a BF-based approach makes more sense.
*/
//TODO currently only works for 128 bit masks
uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, uint8_t* intersect, uint32_t entrybytelen) {
uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, CBitVector* intersection, uint32_t entrybytelen) {
// Create the GHashTable
GHashTable *map = NULL, *tmpmap = NULL;
GHashTableIter iter;
@ -72,9 +76,10 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
NULL // cleanup value
);
uint32_t i, j, intersectsize, ctr = 0;
uint64_t* tmpval = (uint64_t*) malloc(sizeof(uint64_t));
uint32_t i, j, intersectsize, ctr = 0, k;
uint64_t* tmpval;
uint64_t* tmpkey = (uint64_t*) malloc(sizeof(uint64_t));
uint64_t* query;
#ifndef BATCH
cout << "Inserting the items into the hash table " << endl;
#endif
@ -83,7 +88,10 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
#ifdef DEBUG
cout << "Inserted item: " << (hex) << ((uint64_t*) csets[0])[2*i] << " "<< ((uint64_t*) csets[0])[2*i+1] << (dec) << endl;
#endif
g_hash_table_insert(map,(void*) &((uint64_t*)csets[0])[2*i], &(((uint64_t*)csets[0])[2*i+1]));
tmpval = (uint64_t*) malloc(2*sizeof(uint64_t));
tmpval[0] = (((uint64_t*)csets[0])[2*i+1]);
tmpval[1] = i;
g_hash_table_insert(map,(void*) &((uint64_t*)csets[0])[2*i], tmpval);//&(((uint64_t*)csets[0])[2*i+1]));
}
#ifdef DEBUG
g_hash_table_foreach( map, printKeyValue, NULL );
@ -108,11 +116,17 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
cout << "Checking for Key: " << (hex) << ((uint64_t*) csets[i])[2*j] << " "<< ((uint64_t*) csets[i])[2*j+1] << (dec) << endl;
#endif
if(g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)csets[i])[2*j]),
NULL, (void**) &tmpval) && (*tmpval == ((uint64_t*)csets[i])[2*j+1])) {
NULL, (void**) &query) && (*query == ((uint64_t*)csets[i])[2*j+1])) {
#ifdef DEBUG
cout << "Key was found" << endl;
#endif
g_hash_table_insert(tmpmap,(void*) &(((uint64_t*)csets[i])[2*j]),&(((uint64_t*)csets[i])[2*j+1]));
tmpval = (uint64_t*) malloc((i+2)*sizeof(uint64_t));
tmpval[0] = (((uint64_t*)csets[i])[2*j+1]);
for(k = 1; k < i+1; k++) {
tmpval[k] = query[k];
}
tmpval[i+1] = j;
g_hash_table_insert(tmpmap,(void*) &(((uint64_t*)csets[i])[2*j]), tmpval);//&(((uint64_t*)csets[i])[2*j+1]));
} else {
#ifdef DEBUG
@ -143,8 +157,9 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
#ifdef DEBUG
cout << (hex) << tmpkey[0] << " " << tmpval[0] << (dec)<< endl;
#endif
((uint64_t*) intersect)[ctr++] = tmpkey[0];
((uint64_t*) intersect)[ctr++] = tmpval[0];
for(i = 0; i < nclients; i++) {
intersection[i].SetBit(tmpval[i+1], 1);
}
}
gettimeofday(&end, NULL);
#ifdef TIMING
@ -162,69 +177,54 @@ uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** cset
}
uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements,
uint8_t** result, crypto* crypt, CSocket* socket, bool cardinality) {
uint32_t maskbytelen = 16, intersectsize, i, j;
uint8_t* masks = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen);
uint8_t* intersect = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen);
uint32_t* perm;
uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint32_t* tmpval = (uint32_t*) malloc(sizeof(uint32_t));
GHashTable *map;
//crypto crypto(symsecbits, (uint8_t*) const_seed);
// cout << "Starting client with " << neles << " elements of " << (8*elebytelen) << "-bit length with server "
// << address << ":" << port << endl;
CSocket* sockfd = socket;
//connect(address, port, sockfd);
sockfd->Send((uint8_t*) &neles, sizeof(uint32_t));
sockfd->Send((uint8_t*) &maskbytelen, sizeof(uint32_t));
uint32_t client_routine(uint32_t neles, task_ctx ectx, uint32_t* matches,
crypto* crypt_env, CSocket* socket, uint32_t ntasks, bool cardinality) {
uint32_t maskbytelen, intersectsize, i, matchctr;
perm = mask_and_permute_elements(neles, elebytelen, elements, maskbytelen, masks, crypt->get_seclvl().symbits, crypt);
uint8_t* masks;
uint32_t *perm, *invperm;
CBitVector inIntersection(neles);
sockfd->Send(masks, maskbytelen * neles);
//TODO works only fine for equally sized sets, if one set is bigger than the other, this will fail!
maskbytelen = 16;//ceil_divide(crypt_env->get_seclvl().statbits + 2*ceil_log2(neles), 8);
if(!cardinality) {
for(i = 0; i < neles; i++) {
invperm[perm[i]] = i;
}
masks = (uint8_t*) malloc(sizeof(uint8_t) * neles * maskbytelen);
perm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL);
for(i = 0; i < neles; i++) {
g_hash_table_insert(map,(void*) &((uint64_t*)masks)[2*i], &(invperm[i]));
}
/* Generate the random permutation the elements */
crypt_env->gen_rnd_perm(perm, neles);
socket->Send((uint8_t*) &neles, sizeof(uint32_t));
socket->Send((uint8_t*) &maskbytelen, sizeof(uint32_t));
ectx.eles.outbytelen = maskbytelen,
ectx.eles.nelements = neles;
ectx.eles.output = masks;
ectx.eles.perm = perm;
ectx.sctx.symcrypt = crypt_env;
ectx.sctx.keydata = (uint8_t*) const_seed;
run_task(ntasks, ectx, hash);
socket->Send(masks, maskbytelen * neles);
socket->Receive(&intersectsize, sizeof(uint32_t));
for(i = 0; i < neles; i++) {
invperm[perm[i]] = i;
}
sockfd->Receive(&intersectsize, sizeof(uint32_t));
if(!cardinality) {
sockfd->Receive(intersect, maskbytelen * intersectsize);
socket->Receive(inIntersection.GetArr(), ceil_divide(neles, 8));
#ifdef DEBUG
cout << "The intersection contains " << intersectsize << " elements: " << endl;
for(i = 0; i < intersectsize; i++) {
cout << (hex) << ((uint64_t*)intersect)[2*i] << " " << ((uint64_t*)intersect)[2*i+1] << (dec) << endl;
for(i = 0, matchctr = 0; i < neles; i++) {
if(inIntersection.GetBit(i)) {
matches[matchctr] = invperm[i];
matchctr++;
}
}
#endif
*result = (uint8_t*) malloc(elebytelen * intersectsize);
//uint8_t* tmpbuf = (uint8_t*) malloc(maskbytelen);
for(i = 0; i < intersectsize; i++) {
g_hash_table_lookup_extended(map, (void*) &(((uint64_t*)intersect)[2*i]), NULL, (void**) &tmpval);
memcpy((*result) + i * elebytelen, elements + tmpval[0] * elebytelen, elebytelen);
//crypto.decrypt(tmpbuf, intersect+i*maskbytelen, maskbytelen);
//memcpy((*result) + i * elebytelen, tmpbuf, elebytelen);
#ifdef DEBUG
cout << ((uint32_t*) elements)[tmpval[0]] << ", ";
#endif
}
#ifdef DEBUG
cout << endl;
#endif
}
free(perm);
@ -233,51 +233,52 @@ uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements,
return intersectsize;
}
void printKeyValue( gpointer key, gpointer value, gpointer userData ) {
uint64_t realKey = *((uint64_t*)key);
uint64_t realValue = *((uint64_t*)value);
cout << (hex) << realKey << ": " << realValue << (dec) << endl;
return;
}
uint32_t* mask_and_permute_elements(uint32_t neles, uint32_t elebytelen, uint8_t*
elements, uint32_t maskbytelen, uint8_t* masks, uint32_t symsecbits, crypto* crypto) {
uint32_t* perm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint8_t* maskpermptr;
uint32_t i;
//Get random permutation
crypto->gen_rnd_perm(perm, neles);
//crypto->seed_aes_enc(client_psk);
//Hash and permute all elements
for(i = 0; i < neles; i++) {
//cout << "Performing encryption for " << i << "-th element " << ((uint32_t*) elements)[i] << ": ";
maskpermptr = masks + perm[i] * maskbytelen;
crypto->hash(maskpermptr, maskbytelen, elements + i * elebytelen, elebytelen);
//crypto->encrypt(maskpermptr, elements+i*elebytelen, elebytelen);
//cout <<(hex)<< ((uint64_t*) maskpermptr)[0] << ((uint64_t*) maskpermptr)[1] << (dec) << endl;
#ifdef DEBUG
cout << "Resulting hash for element " << ((uint32_t*)elements)[i] << ": " << (hex) << ((uint64_t*) maskpermptr)[0] <<
" " << ((uint64_t*) maskpermptr)[1] << (dec) << endl;
#endif
}
//free(perm);
return perm;
}
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements,
uint8_t** intersection, crypto* crypt, CSocket* sockets, uint32_t nclients, bool cardinality) {
uint8_t** result, crypto* crypt, CSocket* sockets, uint32_t ntasks, uint32_t nclients, bool cardinality) {
if(role == 0) { //Start the server
//TODO maybe rerun infinitely
server_routine(nclients, sockets, cardinality);
return 0;
} else { //Start clients
return client_routine(neles, elebytelen, elements, intersection, crypt, sockets, cardinality);
task_ctx ectx;
ectx.eles.input1d = elements;
ectx.eles.fixedbytelen = elebytelen;
ectx.eles.hasvarbytelen = false;
uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint32_t intersect_size = client_routine(neles, ectx, matches, crypt, sockets, ntasks, cardinality);
create_result_from_matches_fixed_bitlen(result, elebytelen, elements, matches, intersect_size);
free(matches);
return intersect_size;
}
}
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t* elebytelens, uint8_t** elements,
uint8_t*** result, uint32_t** resbytelens, crypto* crypt, CSocket* sockets,
uint32_t ntasks, uint32_t nclients, bool cardinality) {
if(role == 0) { //Start the server
//TODO maybe rerun infinitely
server_routine(nclients, sockets, cardinality);
return 0;
} else { //Start clients
task_ctx ectx;
ectx.eles.input2d = elements;
ectx.eles.varbytelens = elebytelens;
ectx.eles.hasvarbytelen = true;
uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint32_t intersect_size = client_routine(neles, ectx, matches, crypt, sockets, ntasks, cardinality);
create_result_from_matches_var_bitlen(result, resbytelens, elebytelens, elements, matches, intersect_size);
free(matches);
return intersect_size;
}
}

View File

@ -1,5 +1,5 @@
/*
* shpsi.h
* sapsi.h
*
* Created on: Jul 1, 2014
* Author: mzohner
@ -13,12 +13,18 @@
#include "../util/socket.h"
#include "../util/typedefs.h"
#include "../util/connection.h"
#include "../util/helpers.h"
#include "../util/cbitvector.h"
/* start both roles*/
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* elements,
uint8_t** intersection, crypto* crypt, CSocket* socket, uint32_t nclients = 0, bool cardinality=false);
uint8_t** intersection, crypto* crypt, CSocket* socket, uint32_t ntasks, uint32_t nclients = 2, bool cardinality=false);
uint32_t ttppsi(role_type role, uint32_t neles, uint32_t* elebytelens, uint8_t** elements,
uint8_t*** result, uint32_t** resbytelens, crypto* crypt, CSocket* sockets,
uint32_t ntasks, uint32_t nclients = 2, bool cardinality = false);
/*
* Params:
@ -30,8 +36,8 @@ uint32_t ttppsi(role_type role, uint32_t neles, uint32_t elebytelen, uint8_t* el
* port: port that the server is listening on
* return: number of intersecting elements
*/
uint32_t client_routine(uint32_t neles, uint32_t elebytelen, uint8_t* elements,
uint8_t** intersection, crypto* crypt, CSocket* socket, bool cardinality);
uint32_t client_routine(uint32_t neles, task_ctx ectx, uint32_t* matches, crypto* crypt,
CSocket* socket, uint32_t ntasks, bool cardinality);
/*
* Mask and permute the elements using the pre-shared key
@ -48,7 +54,7 @@ uint32_t* mask_and_permute_elements(uint32_t neles, uint32_t elebytelen, uint8_t
*/
void server_routine(uint32_t nclients, CSocket* socket, bool cardinality);
uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, uint8_t* intersect, uint32_t entrybytelen);
uint32_t compute_intersection(uint32_t nclients, uint32_t* neles, uint8_t** csets, CBitVector* intersection, uint32_t entrybytelen);
void printKeyValue( gpointer key, gpointer value, gpointer userData );

View File

@ -120,13 +120,13 @@ void crypto::gen_rnd_uniform(uint8_t* resbuf, uint64_t mod) {
void crypto::encrypt(AES_KEY_CTX* enc_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) {
int32_t dummy;
EVP_EncryptUpdate(enc_key, resbuf, &dummy, inbuf, ninbytes);
//EVP_EncryptFinal_ex(enc_key, resbuf, &dummy);
EVP_EncryptFinal_ex(enc_key, resbuf, &dummy);
}
void crypto::decrypt(AES_KEY_CTX* dec_key, uint8_t* resbuf, uint8_t* inbuf, uint32_t ninbytes) {
int32_t dummy;
//cout << "inbuf = " << (hex) << ((uint64_t*) inbuf)[0] << ((uint64_t*) inbuf)[1] << (dec) << endl;
EVP_DecryptUpdate(dec_key, resbuf, &dummy, inbuf, ninbytes);
//EVP_DecryptFinal_ex(dec_key, resbuf, &dummy);
EVP_DecryptFinal_ex(dec_key, resbuf, &dummy);
//cout << "outbuf = " << (hex) << ((uint64_t*) resbuf)[0] << ((uint64_t*) resbuf)[1] << (dec) << " (" << dummy << ")" << endl;
}
@ -212,6 +212,10 @@ void crypto::hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t
hash_routine(resbuf, noutbytes, inbuf, ninbytes, sha_hash_buf);
}
void crypto::hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint8_t* tmpbuf) {
hash_routine(resbuf, noutbytes, inbuf, ninbytes, tmpbuf);
}
//A fixed-key hashing scheme that uses AES, should not be used for real hashing, hashes to AES_BYTES bytes
void crypto::fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes) {

View File

@ -62,6 +62,7 @@ public:
//Hash routines
void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes);
void hash(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint8_t* tmpbuf);
void hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes, uint32_t ctr);
void fixed_key_aes_hash(AES_KEY_CTX* aes_key, uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes);
void fixed_key_aes_hash_ctr(uint8_t* resbuf, uint32_t noutbytes, uint8_t* inbuf, uint32_t ninbytes);

View File

@ -31,12 +31,12 @@ struct element_ctx {
bool hasvarbytelen;
};
struct hash_ctx {
struct sym_ctx {
crypto* symcrypt;
uint8_t* keydata;
};
struct encrypt_ctx {
struct asym_ctx {
num* exponent;
pk_crypto* field;
bool sample;
@ -45,8 +45,8 @@ struct encrypt_ctx {
struct task_ctx {
element_ctx eles;
union {
hash_ctx hctx;
encrypt_ctx ectx;
sym_ctx sctx;
asym_ctx actx;
};
};
@ -108,20 +108,20 @@ static void create_result_from_matches_fixed_bitlen(uint8_t** result, uint32_t i
}
}
static void *encrypt(void* context) {
static void *asym_encrypt(void* context) {
#ifdef DEBUG
cout << "Encryption task started" << endl;
#endif
pk_crypto* field = ((task_ctx*) context)->ectx.field;
pk_crypto* field = ((task_ctx*) context)->actx.field;
element_ctx electx = ((task_ctx*) context)->eles;
num* e = ((task_ctx*) context)->ectx.exponent;
num* e = ((task_ctx*) context)->actx.exponent;
fe* tmpfe = field->get_fe();
uint8_t *inptr=electx.input1d, *outptr=electx.output;
uint32_t i;
for(i = 0; i < electx.nelements; i++, inptr+=electx.fixedbytelen, outptr+=electx.outbytelen) {
if(((task_ctx*) context)->ectx.sample) {
if(((task_ctx*) context)->actx.sample) {
tmpfe->sample_fe_from_bytes(inptr, electx.fixedbytelen);
//cout << "Mapped " << ((uint32_t*) inptr)[0] << " to ";
} else {
@ -135,29 +135,71 @@ static void *encrypt(void* context) {
return 0;
}
static void *hash(void* context) {
static void *sym_encrypt(void* context) {
#ifdef DEBUG
cout << "Hashing thread started" << endl;
#endif
hash_ctx hdata = ((task_ctx*) context)->hctx;
sym_ctx hdata = ((task_ctx*) context)->sctx;
element_ctx electx = ((task_ctx*) context)->eles;
crypto* crypt_env = hdata.symcrypt;
AES_KEY_CTX aes_key;
//cout << "initializing key" << endl;
crypt_env->init_aes_key(&aes_key, hdata.keydata);
//cout << "initialized key" << endl;
uint8_t* aes_buf = (uint8_t*) malloc(AES_BYTES);
uint32_t* perm = electx.perm;
uint32_t i;
if(electx.hasvarbytelen) {
uint8_t **inptr = electx.input2d;
for(i = electx.startelement; i < electx.endelement; i++) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i]);
//crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i]);
//cout << "encrypting i = " << i << ", perm = " << perm [i] << ", outbytelen = " << electx.outbytelen << endl;
crypt_env->encrypt(&aes_key, aes_buf, inptr[i], electx.varbytelens[i]);
memcpy(electx.output+perm[i]*electx.outbytelen, aes_buf, electx.outbytelen);
}
} else {
uint8_t *inptr = electx.input1d;
for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen);
//crypt_env->hash(&aes_key, electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen);
crypt_env->encrypt(&aes_key, aes_buf, inptr, electx.fixedbytelen);
memcpy(electx.output+perm[i]*electx.outbytelen, aes_buf, electx.outbytelen);
}
}
//cout << "Returning" << endl;
//free(aes_buf);
return 0;
}
static void *hash(void* context) {
#ifdef DEBUG
cout << "Hashing thread started" << endl;
#endif
sym_ctx hdata = ((task_ctx*) context)->sctx;
element_ctx electx = ((task_ctx*) context)->eles;
crypto* crypt_env = hdata.symcrypt;
uint32_t* perm = electx.perm;
uint32_t i;
uint8_t* tmphashbuf = (uint8_t*) malloc(crypt_env->get_hash_bytes());
if(electx.hasvarbytelen) {
uint8_t **inptr = electx.input2d;
for(i = electx.startelement; i < electx.endelement; i++) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr[i], electx.varbytelens[i], tmphashbuf);
}
} else {
uint8_t *inptr = electx.input1d;
for(i = electx.startelement; i < electx.endelement; i++, inptr+=electx.fixedbytelen) {
crypt_env->hash(electx.output+perm[i]*electx.outbytelen, electx.outbytelen, inptr, electx.fixedbytelen, tmphashbuf);
}
}
free(tmphashbuf);
return 0;
}
@ -198,8 +240,6 @@ static void run_task(uint32_t nthreads, task_ctx context, void* (*func)(void*) )
neles_cur = min(context.eles.nelements - electr, neles_thread);
memcpy(contexts + i, &context, sizeof(task_ctx));
contexts[i].eles.nelements = neles_cur;
//contexts[i].eles.input = context.eles.input + (context.eles.inbytelen * electr);
//contexts[i].eles.output = context.eles.output + (context.eles.outbytelen * electr);
contexts[i].eles.startelement = electr;
contexts[i].eles.endelement = electr + neles_cur;
electr += neles_cur;
@ -228,7 +268,6 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
uint32_t hashbytelen, uint32_t* perm, uint32_t* matches) {
uint32_t* invperm = (uint32_t*) malloc(sizeof(uint32_t) * neles);
//uint32_t* matches = (uint32_t*) malloc(sizeof(uint32_t) * neles);
uint64_t* tmpval;
uint32_t size_intersect, i, intersect_ctr;
@ -236,17 +275,12 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
for(i = 0; i < neles; i++) {
invperm[perm[i]] = i;
}
//cout << "My number of elements. " << neles << ", partner number of elements: " << pneles << ", maskbytelen: " << hashbytelen << endl;
GHashTable *map= g_hash_table_new_full(g_int64_hash, g_int64_equal, NULL, NULL);
for(i = 0; i < neles; i++) {
g_hash_table_insert(map,(void*) ((uint64_t*) &(hashes[i*hashbytelen])), &(invperm[i]));
}
//for(i = 0; i < pneles; i++) {
// ((uint64_t*) &(phashes[i*hashbytelen]))[0]++;
//}
for(i = 0, intersect_ctr = 0; i < pneles; i++) {
if(g_hash_table_lookup_extended(map, (void*) ((uint64_t*) &(phashes[i*hashbytelen])),
@ -259,14 +293,7 @@ static uint32_t find_intersection(uint8_t* hashes, uint32_t neles, uint8_t* phas
size_intersect = intersect_ctr;
//result = (uint8_t**) malloc(sizeof(uint8_t*));
//(*result) = (uint8_t*) malloc(sizeof(uint8_t) * size_intersect * elebytelen);
//for(i = 0; i < size_intersect; i++) {
// memcpy((*result) + i * elebytelen, elements + matches[i] * elebytelen, elebytelen);
//}
free(invperm);
//free(matches);
return size_intersect;
}