[go: up one dir, main page]

Skip to content

Commit

Permalink
solver, math: fix artificial variable expr terms sharing the same memory as the term being optimized away; add clone methods to Expr and Constraint; speed up solveFor() by skipping the rescaling of the remaining terms when the computed multiplier (-1 / coefficient of the solved-for term) is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
lithdew committed May 30, 2020
1 parent 3fba220 commit 1f9f4b9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
17 changes: 16 additions & 1 deletion math.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func next(typ SymbolKind) Symbol {
}

// Kind returns the SymbolKind stored in the two highest bits of sym
// (sym >> 62). NOTE(review): assumes Symbol is an unsigned 64-bit integer so
// the shift is logical — confirm against the Symbol declaration.
func (sym Symbol) Kind() SymbolKind { return SymbolKind(sym >> 62) }
// Zero (pre-commit form) reports whether the low 62 bits of sym are all zero,
// ignoring the kind bits. NOTE(review): in this diff view this line appears to
// be the removed side of the hunk; the following definition replaces it.
func (sym Symbol) Zero() bool { return sym&0x3fffffffffffffff == 0 }
// Zero reports whether sym is exactly the package-level `zero` symbol, kind
// bits included. The solver uses `zero` as a "no symbol chosen" sentinel
// (e.g. the entry/exit variables before a pivot candidate is found).
func (sym Symbol) Zero() bool { return sym == zero }
// Restricted reports whether sym is a non-zero symbol whose kind reports
// itself as restricted via SymbolKind.Restricted.
func (sym Symbol) Restricted() bool { return !sym.Zero() && sym.Kind().Restricted() }
// External reports whether sym is a non-zero symbol of kind External.
func (sym Symbol) External() bool { return !sym.Zero() && sym.Kind() == External }
// Slack reports whether sym is a non-zero symbol of kind Slack.
func (sym Symbol) Slack() bool { return !sym.Zero() && sym.Kind() == Slack }
Expand Down Expand Up @@ -84,6 +84,11 @@ func NewConstraint(op Op, constant float64, terms ...Term) Constraint {
return Constraint{op: op, expr: NewExpr(constant, terms...)}
}

// clone returns a copy of c carrying the same operator and an independent
// copy of its expression (via Expr.clone). Storing a clone rather than c
// itself keeps two holders of "the same" row — e.g. a tableau entry and the
// artificial objective — from aliasing one terms slice, so an in-place edit
// through one does not silently rewrite the other.
func (c Constraint) clone() Constraint {
	return Constraint{
		op:   c.op,
		expr: c.expr.clone(),
	}
}

type Term struct {
coeff float64
id Symbol
Expand All @@ -98,6 +103,12 @@ func NewExpr(constant float64, terms ...Term) Expr {
return Expr{constant: constant, terms: terms}
}

// clone returns a copy of c whose terms are held in a freshly allocated
// slice, so that in-place mutation of one expression's terms (solveFor, for
// instance, rescales each term's coefficient in place) is not observed by
// the other. The copied terms slice is always non-nil, even when c has no
// terms.
func (c Expr) clone() Expr {
	terms := make([]Term, len(c.terms))
	copy(terms, c.terms)
	return Expr{constant: c.constant, terms: terms}
}

func (c Expr) find(id Symbol) int {
for i := 0; i < len(c.terms); i++ {
if c.terms[i].id == id {
Expand Down Expand Up @@ -152,6 +163,10 @@ func (c *Expr) solveFor(id Symbol) {
coeff := -1.0 / c.terms[idx].coeff
c.delete(idx)

if coeff == 1.0 {
return
}

c.constant *= coeff
for i := 0; i < len(c.terms); i++ {
c.terms[i].coeff *= coeff
Expand Down
20 changes: 8 additions & 12 deletions solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,9 @@ func (s *Solver) substitute(id Symbol, expr Expr) {
row := s.tabs[symbol]
row.expr.substitute(id, expr)
s.tabs[symbol] = row

if symbol.External() || row.expr.constant >= 0.0 {
continue
}

s.infeasible = append(s.infeasible, symbol)
}
s.objective.substitute(id, expr)
Expand All @@ -360,11 +358,10 @@ func (s *Solver) optimizeAgainst(objective *Expr) error {
exit := zero

for _, term := range objective.terms {
if term.id.Dummy() || term.coeff >= 0.0 {
continue
if !term.id.Dummy() && term.coeff < 0.0 {
entry = term.id
break
}
entry = term.id
break
}
if entry.Zero() {
return nil
Expand Down Expand Up @@ -403,8 +400,8 @@ func (s *Solver) optimizeAgainst(objective *Expr) error {
func (s *Solver) augmentArtificialVariable(row Constraint) error {
art := next(Slack)

s.tabs[art] = row
s.artificial = row.expr
s.tabs[art] = row.clone()
s.artificial = row.expr.clone()

err := s.optimizeAgainst(&s.artificial)
if err != nil {
Expand All @@ -424,11 +421,10 @@ func (s *Solver) augmentArtificialVariable(row Constraint) error {

entry := zero
for _, term := range artificial.expr.terms {
if !term.id.Restricted() {
continue
if term.id.Restricted() {
entry = term.id
break
}
entry = term.id
break
}
if entry.Zero() {
return errors.New("unsatisfiable")
Expand Down

0 comments on commit 1f9f4b9

Please sign in to comment.