From c9beab26fe04a522fc8c7f3603387823cea86200 Mon Sep 17 00:00:00 2001 From: David Florness Date: Mon, 8 Jun 2020 20:20:15 -0600 Subject: [PATCH] Try to cleanup locks --- voter.go | 85 +++++++++++++++++++++++--------------------------------- 1 file changed, 35 insertions(+), 50 deletions(-) diff --git a/voter.go b/voter.go index edd1b10..0e79d7f 100644 --- a/voter.go +++ b/voter.go @@ -21,8 +21,10 @@ import ( ) type Voter struct { - sync.RWMutex sum *big.Int + // may get more than 1 eval from peer; doesn't need to be RW because we + // never serve it to peers + inputMu sync.Mutex input *big.Int output *big.Int addrInfo peer.AddrInfo @@ -31,9 +33,16 @@ type Voter struct { type Me struct { Voter host.Host - ctx context.Context - kdht *dht.IpfsDHT - poly *Poly + ctx context.Context + kdht *dht.IpfsDHT + + poly *Poly + + // mutexs only used for atomicity; atomicity.Value sucks because we lose + // type safety with interface{} + polyMu sync.RWMutex // poly is computed after ballot; don't want R/W data races + sumMu sync.RWMutex // sum is computed in a loop; don + inputMu sync.RWMutex // TODO remove by generating input right away } type Election struct { @@ -110,8 +119,8 @@ func handleCmd(cmd string, rw *bufio.ReadWriter, stream network.Stream) { }, } case "eval": // peer is giving their input and requesting output from our poly - me.RLock() - defer me.RUnlock() + me.polyMu.RLock() + defer me.polyMu.RUnlock() if me.poly == nil { logger.Warning("peer attempted to eval before we had our poly:", stream.Conn().RemotePeer()) @@ -132,19 +141,19 @@ func handleCmd(cmd string, rw *bufio.ReadWriter, stream network.Stream) { } peer, exists := election.remoteVoters[stream.Conn().RemotePeer()] if !exists { - logger.Warning("receiving eval command from unrecognized peer") + logger.Warning("received eval command from unrecognized peer") return } - peer.Lock() + peer.inputMu.Lock() + defer peer.inputMu.Unlock() peer.input = new(big.Int).SetBytes(inputBytes) logger.Infof("%s input: %s", peer.addrInfo.ID, peer.input) - peer.Unlock() output := me.poly.Eval(peer.input) rw.WriteString(base58.Encode(output.Bytes())) rw.Flush() case "sum": - me.RLock() - defer me.RUnlock() + me.sumMu.RLock() + defer me.sumMu.RUnlock() if me.sum == nil { logger.Info("peer attempted to fetch sum "+ "before we computed it:", stream.Conn().RemotePeer()) @@ -153,8 +162,8 @@ func handleCmd(cmd string, rw *bufio.ReadWriter, stream network.Stream) { rw.WriteString(base58.Encode(me.sum.Bytes())) rw.Flush() case "input": - me.RLock() - defer me.RUnlock() + me.inputMu.RLock() + defer me.inputMu.RLock() if me.input == nil { logger.Info("peer attempted to fetch input "+ "before we had one:", stream.Conn().RemotePeer()) @@ -229,7 +238,7 @@ func findPeers(closeElection <-chan int) { logger.Info("done finding peers") } -func (voter *Voter) fetchNumber(done chan<- *big.Int, cmd string, args ...string) { +func (voter *Voter) fetchNumber(cmd string, args ...string) *big.Int { printErr := func(err error, msg string) { logger.Errorf("%s: %s fetcing `%s'; retrying in 2 seconds", voter.addrInfo.ID, msg, cmd) @@ -276,12 +285,14 @@ retry: printErr(err, "couldn't base58-decode contents from stream") goto retry } - done <- new(big.Int).SetBytes(retBytes) + return new(big.Int).SetBytes(retBytes) } func startVoting() { var err error + me.inputMu.Lock() me.input, err = RandomBigInt(128, false) + me.inputMu.Unlock() if err != nil { panic(err) } @@ -290,55 +301,37 @@ func startVoting() { ballot := vote(candidates) // no +1 since we want degree k-1 where k is total number of voters + me.polyMu.Lock() me.poly = NewRandomPoly(uint(len(election.remoteVoters)), 1024, ballot) + me.polyMu.Unlock() // get outputs var wg sync.WaitGroup for _, voter := range election.remoteVoters { wg.Add(1) go func(voter *Voter) { - done := make(chan *big.Int, 1) - me.RLock() - go voter.fetchNumber(done, "eval", - base58.Encode(me.input.Bytes())) - me.RUnlock() - output := <- done - voter.Lock() - voter.output = output - voter.Unlock() - voter.RLock() - logger.Infof("voter %s output: %s", - voter.addrInfo.ID, voter.output) - voter.RUnlock() + voter.output = voter.fetchNumber("eval", base58.Encode(me.input.Bytes())) + logger.Infof("voter %s output: %s", voter.addrInfo.ID, voter.output) wg.Done() }(voter) } wg.Wait() // calculate sum - me.Lock() + me.sumMu.Lock() me.sum = me.poly.Eval(me.input) for _, voter := range election.remoteVoters { me.sum.Add(me.sum, voter.output) } - me.Unlock() + me.sumMu.Unlock() // get sums for _, voter := range election.remoteVoters { wg.Add(1) go func(voter *Voter) { - done := make(chan *big.Int, 1) - me.RLock() - go voter.fetchNumber(done, "sum", base58.Encode(me.sum.Bytes())) - me.RUnlock() - sum := <- done - voter.Lock() - voter.sum = sum - voter.Unlock() - voter.RLock() + voter.sum = voter.fetchNumber("sum") logger.Infof("voter %s sum: %s", voter.addrInfo.ID, voter.sum) - voter.RUnlock() wg.Done() }(voter) } @@ -346,18 +339,10 @@ func startVoting() { // ensure we have everyone's inputs for _, voter := range election.remoteVoters { - voter.RLock() - haveInput := voter.input != nil - voter.RUnlock() - if !haveInput { + if voter.input == nil { wg.Add(1) go func(voter *Voter) { - done := make(chan *big.Int, 1) - go voter.fetchNumber(done, "input") - input := <- done - voter.Lock() - voter.input = input - voter.Unlock() + voter.input = voter.fetchNumber("input") }(voter) } } -- 2.38.4