/*++
Copyright (c) 2007 Microsoft Corporation

Module Name:

    pull_quant.cpp

Abstract:

    Pull nested quantifiers.

Author:

    Leonardo (leonardo) 2008-01-20

Notes:

--*/
#include"pull_quant.h"
#include"var_subst.h"
#include"rewriter_def.h"
#include"ast_pp.h"

struct pull_quant::imp {
    
    /**
       \brief Rewriter configuration that pulls quantifiers out of AND, OR and NOT
       applications, and merges a universal quantifier whose body is itself a
       universal quantifier into a single universal quantifier.
    */
    struct rw_cfg : public default_rewriter_cfg {
        ast_manager & m_manager;
        shift_vars    m_shift;
        
        rw_cfg(ast_manager & m):
            m_manager(m), 
            m_shift(m) {
        }

        /**
           \brief Try to pull the quantifiers that occur as direct children of
           d(children[0], ..., children[num_children-1]) into one quantifier that
           encloses the application.

           - If d is NOT and its only child is a quantifier, the quantifier kind is
             flipped and the negation is moved into its body.
           - Otherwise the bound variables of every quantified child are collected
             into a single binder, and the variable indices of every child are
             shifted so that they refer to that binder.

           Returns true and stores the new quantifier in \c result on success.
           Returns false, without writing \c result, when no child is a quantifier.
           In the merged case the result has the minimum weight of the nested
           quantifiers and the qid of the first quantified child that binds
           variables; nested patterns are dropped (see the remark below).
        */
        bool pull_quant1_core(func_decl * d, unsigned num_children, expr * const * children, expr_ref & result) {
            ptr_buffer<sort>  var_sorts;
            buffer<symbol>    var_names;
            symbol            qid;
            int               w = INT_MAX;
            
            // The input formula is in Skolem normal form...
            // So all children are forall (positive context) or exists (negative context).
            // Remark: (AND a1 ...) may be represented (NOT (OR (NOT a1) ...))
            // So, when pulling a quantifier over a NOT, it becomes an exists.
            
            if (m_manager.is_not(d)) {
                SASSERT(num_children == 1);
                expr * child = children[0];
                if (is_quantifier(child)) {
                    quantifier * q = to_quantifier(child);
                    expr * body = q->get_expr();
                    result = m_manager.update_quantifier(q, !q->is_forall(), m_manager.mk_not(body));
                    return true;
                }
                else {
                    return false;
                }
            }
            
            bool found_quantifier = false;
            // Only meaningful once found_quantifier is true. It is initialized so the
            // read in mk_quantifier below is never flagged as possibly uninitialized.
            bool forall_children  = true;
            
            for (unsigned i = 0; i < num_children; i++) {
                expr * child = children[i];
                if (is_quantifier(child)) {
                    
                    if (!found_quantifier) {
                        found_quantifier = true;
                        forall_children  = is_forall(child);
                    }
                    else {
                        // Since the initial formula was in SNF, all children must be EXISTS or FORALL.
                        SASSERT(forall_children == is_forall(child));
                    }
                    
                    quantifier * nested_q = to_quantifier(child);
                    if (var_sorts.empty()) {
                        // use the qid of one of the nested quantifiers.
                        qid = nested_q->get_qid();
                    }
                    w = std::min(w, nested_q->get_weight());
                    // Decls are pushed last-to-first; both buffers are reversed
                    // before building the result quantifier.
                    unsigned j = nested_q->get_num_decls();
                    while (j > 0) {
                        --j;
                        var_sorts.push_back(nested_q->get_decl_sort(j));
                        symbol s = nested_q->get_decl_name(j);
                        // Give a fresh name to a variable whose name was already
                        // collected from another child, so the merged binder has
                        // distinct names.
                        if (std::find(var_names.begin(), var_names.end(), s) != var_names.end())
                            var_names.push_back(m_manager.mk_fresh_var_name(s.is_numerical() ? 0 : s.bare_str()));
                        else
                            var_names.push_back(s);
                    }
                }
            }
            
            if (!var_sorts.empty()) {
                SASSERT(found_quantifier);
                // adjust the variable ids in formulas in new_children
                expr_ref_buffer   new_adjusted_children(m_manager);
                expr_ref          adjusted_child(m_manager);
                unsigned          num_decls = var_sorts.size();
                unsigned          shift_amount = 0;
                TRACE("pull_quant", tout << "Result num decls:" << num_decls << "\n";);
                for (unsigned i = 0; i < num_children; i++) {
                    expr * child = children[i];
                    if (!is_quantifier(child)) {
                        // increment the free variables in child by num_decls because
                        // child will be in the scope of num_decls bound variables.
                        m_shift(child, num_decls, adjusted_child);
                        TRACE("pull_quant", tout << "shifted by: " << num_decls << "\n" << 
                              mk_pp(child, m_manager) << "\n---->\n" << mk_pp(adjusted_child, m_manager) << "\n";);
                    }
                    else {
                        quantifier * nested_q = to_quantifier(child);
                        SASSERT(num_decls >= nested_q->get_num_decls());
                        // Assume nested_q is of the form 
                        // forall xs. P(xs, ys)
                        // where xs (ys) represents the set of bound (free) variables.
                        //
                        // - the index of the variables xs must be increased by shift_amount.
                        //   That is, the number of new bound variables that will precede the bound
                        //   variables xs.
                        //
                        // - the index of the variables ys must be increased by num_decls - nested_q->get_num_decls. 
                        //   That is, the total number of new bound variables that will be in the scope
                        //   of nested_q->get_expr().
                        m_shift(nested_q->get_expr(), 
                                nested_q->get_num_decls(),             // bound for shift1/shift2
                                num_decls - nested_q->get_num_decls(), // shift1  (shift by this amount if var idx >= bound)
                                shift_amount,                          // shift2  (shift by this amount if var idx < bound)
                                adjusted_child);
                        TRACE("pull_quant", tout << "shifted  bound: " << nested_q->get_num_decls() << " shift1: " << (num_decls - nested_q->get_num_decls()) <<
                              " shift2: " << shift_amount << "\n" << mk_pp(nested_q->get_expr(), m_manager) << 
                              "\n---->\n" << mk_pp(adjusted_child, m_manager) << "\n";);
                        shift_amount += nested_q->get_num_decls();
                    }
                    new_adjusted_children.push_back(adjusted_child);
                }
                
                // Remark: patterns are ignored.
                // This is ok, since this functor is used in one of the following cases:
                //
                // 1) Superposition calculus is being used, so the
                // patterns are useless.
                //
                // 2) No patterns were provided, and the functor is used
                // to increase the effectiveness of the pattern inference
                // procedure.
                //
                // 3) MBQI 
                std::reverse(var_sorts.begin(), var_sorts.end());
                std::reverse(var_names.begin(), var_names.end());
                result = m_manager.mk_quantifier(forall_children,
                                                 var_sorts.size(),
                                                 var_sorts.c_ptr(),
                                                 var_names.c_ptr(),
                                                 m_manager.mk_app(d, new_adjusted_children.size(), new_adjusted_children.c_ptr()),
                                                 w,
                                                 qid);
                return true;
            }
            else {
                SASSERT(!found_quantifier);
                return false;
            }
        }

        /**
           \brief Same as pull_quant1_core(d, ...), but when nothing can be pulled,
           \c result is set to d(children) instead.
        */
        void pull_quant1(func_decl * d, unsigned num_children, expr * const * children, expr_ref & result) {
            if (!pull_quant1_core(d, num_children, children, result)) {
                result = m_manager.mk_app(d, num_children, children);
            }
        }


        /**
           \brief Merge the universal quantifier \c q with \c new_expr, a universal
           quantifier standing for q's body, into a single universal quantifier.

           The merged binder lists q's variables followed by nested_q's. With that
           order, the variable indices used in nested_q's body already refer to
           the right variables, so that body is reused without shifting. The
           result takes q's qid and the smaller of the two weights. Patterns are
           dropped.
        */
        void pull_quant1_core(quantifier * q, expr * new_expr, expr_ref & result) {
            // The original formula was in SNF, so the original quantifiers must be universal.
            SASSERT(is_forall(q));
            SASSERT(is_forall(new_expr));
            quantifier * nested_q = to_quantifier(new_expr);
            ptr_buffer<sort> var_sorts;
            buffer<symbol>   var_names;
            var_sorts.append(q->get_num_decls(), const_cast<sort**>(q->get_decl_sorts()));
            var_sorts.append(nested_q->get_num_decls(), const_cast<sort**>(nested_q->get_decl_sorts()));
            var_names.append(q->get_num_decls(), const_cast<symbol*>(q->get_decl_names()));
            var_names.append(nested_q->get_num_decls(), const_cast<symbol*>(nested_q->get_decl_names()));
            // Remark: patterns are ignored.
            // See the remark in pull_quant1_core(func_decl *, ...) above.
            result = m_manager.mk_forall(var_sorts.size(),
                                         var_sorts.c_ptr(),
                                         var_names.c_ptr(),
                                         nested_q->get_expr(),
                                         std::min(q->get_weight(), nested_q->get_weight()),
                                         q->get_qid());
        }

        /**
           \brief Combine the universal quantifier \c q with \c new_expr, which
           replaces q's body. If new_expr is universal, both binders are merged.
           Otherwise q is rebuilt with body new_expr; new_expr is not expected to
           be an existential quantifier here.
        */
        void pull_quant1(quantifier * q, expr * new_expr, expr_ref & result) {
            // The original formula was in SNF, so the original quantifiers must be universal.
            SASSERT(is_forall(q));
            if (is_forall(new_expr)) { 
                pull_quant1_core(q, new_expr, result);
            }
            else {
                SASSERT(!is_quantifier(new_expr));
                result = m_manager.update_quantifier(q, new_expr);
            }
        }

        /**
           \brief Perform a single, non-recursive pull step on \c n.
           Applications go through pull_quant1(func_decl *, ...), quantifiers are
           combined with their own body, and any other term is returned unchanged.
        */
        void pull_quant1(expr * n, expr_ref & result) {
            if (is_app(n))
                pull_quant1(to_app(n)->get_decl(), to_app(n)->get_num_args(), to_app(n)->get_args(), result);
            else if (is_quantifier(n))
                pull_quant1(to_quantifier(n), to_quantifier(n)->get_expr(), result);
            else
                result = n;
        }
        
        // Code for proof generation...
        /**
           \brief Two-level pull on \c n: first a single pull step on each argument
           (when n is an application) or on the body (when n is a quantifier),
           then a pull step on n itself. The result is stored in \c r.

           \c pr is reset to null and is only set when the manager has fine-grained
           proofs enabled. In that case it chains the congruence/quant_intro step
           for the children with the pull_quant step for n.
        */
        void pull_quant2(expr * n, expr_ref & r, proof_ref & pr) {
            pr = 0;
            if (is_app(n)) {
                expr_ref_buffer   new_args(m_manager);
                expr_ref          new_arg(m_manager);
                ptr_buffer<proof> proofs;
                unsigned num = to_app(n)->get_num_args();
                for (unsigned i = 0; i < num; i++) {
                    expr * arg = to_app(n)->get_arg(i); 
                    pull_quant1(arg , new_arg);
                    new_args.push_back(new_arg);
                    // An argument only changes when a quantifier was pulled, so
                    // new_arg is a quantifier here.
                    if (new_arg != arg)
                        proofs.push_back(m_manager.mk_pull_quant(arg, to_quantifier(new_arg)));
                }
                pull_quant1(to_app(n)->get_decl(), new_args.size(), new_args.c_ptr(), r);
                if (m_manager.fine_grain_proofs()) {
                    app   * r1 = m_manager.mk_app(to_app(n)->get_decl(), new_args.size(), new_args.c_ptr());
                    proof * p1 = proofs.empty() ? 0 : m_manager.mk_congruence(to_app(n), r1, proofs.size(), proofs.c_ptr());
                    proof * p2 = r1 == r ? 0 : m_manager.mk_pull_quant(r1, to_quantifier(r));
                    pr = m_manager.mk_transitivity(p1, p2);
                }
            }
            else if (is_quantifier(n)) {
                expr_ref new_expr(m_manager);
                pull_quant1(to_quantifier(n)->get_expr(), new_expr);
                pull_quant1(to_quantifier(n), new_expr, r);
                if (m_manager.fine_grain_proofs()) {
                    quantifier * q1 = m_manager.update_quantifier(to_quantifier(n), new_expr);
                    proof * p1 = 0;
                    if (n != q1) {
                        proof * p0 = m_manager.mk_pull_quant(to_quantifier(n)->get_expr(), to_quantifier(new_expr));
                        p1 = m_manager.mk_quant_intro(to_quantifier(n), q1, p0);
                    }
                    proof * p2 = q1 == r ? 0 : m_manager.mk_pull_quant(q1, to_quantifier(r));
                    pr = m_manager.mk_transitivity(p1, p2);
                }
            }
            else {
                r  = n;
            }
        }

        /**
           \brief Rewriter hook. Only AND, OR and NOT applications are considered.
           Returns BR_DONE (with a pull_quant proof step when proofs are enabled)
           if a quantifier was pulled, and BR_FAILED otherwise.
        */
        br_status reduce_app(func_decl * f, unsigned num, expr * const * args, expr_ref & result, proof_ref & result_pr) {
            if (!m_manager.is_or(f) && !m_manager.is_and(f) && !m_manager.is_not(f))
                return BR_FAILED;

            if (!pull_quant1_core(f, num, args, result))
                return BR_FAILED;

            if (m_manager.proofs_enabled()) {
                result_pr = m_manager.mk_pull_quant(m_manager.mk_app(f, num, args), 
                                                    to_quantifier(result.get()));
            }
            return BR_DONE;
        }

        /**
           \brief Rewriter hook. If the rewritten body \c new_body of the universal
           quantifier \c old_q is itself universal, the two binders are merged and
           true is returned. Existential quantifiers are not expected (SNF input).
           new_patterns and new_no_patterns are ignored.
        */
        bool reduce_quantifier(quantifier * old_q, 
                               expr * new_body, 
                               expr * const * new_patterns, 
                               expr * const * new_no_patterns,
                               expr_ref & result,
                               proof_ref & result_pr) {

            if (old_q->is_exists()) {
                UNREACHABLE();
                return false;
            }

            if (!is_forall(new_body))
                return false;

            pull_quant1_core(old_q, new_body, result);
            if (m_manager.proofs_enabled())
                result_pr = m_manager.mk_pull_quant(old_q, to_quantifier(result.get()));
            return true;
        }
    };

    // Rewriter driven by rw_cfg. The base class receives a reference to m_cfg
    // before m_cfg is constructed; NOTE(review): this relies on rewriter_tpl only
    // storing that reference in its constructor.
    struct rw : public rewriter_tpl<rw_cfg> {
        rw_cfg m_cfg;
        rw(ast_manager & m):
            rewriter_tpl<rw_cfg>(m, m.proofs_enabled(), m_cfg),
            m_cfg(m) {
        }
    };
    
    rw m_rw;

    imp(ast_manager & m):
        m_rw(m) {
    }
    
    // Rewrite n with the quantifier-pulling rewriter; r gets the result, p the proof.
    void operator()(expr * n, expr_ref & r, proof_ref & p) {
        m_rw(n, r, p);
    }
};

/**
   \brief Create a quantifier puller working over the terms of manager \c m.
*/
pull_quant::pull_quant(ast_manager & m):
    m_imp(alloc(imp, m)) {
}

/**
   \brief Release the implementation object allocated by the constructor.
*/
pull_quant::~pull_quant() { dealloc(m_imp); }

/**
   \brief Pull quantifiers in \c n; the rewritten term goes to \c r and the
   accompanying proof to \c p.
*/
void pull_quant::operator()(expr * n, expr_ref & r, proof_ref & p) {
    m_imp->operator()(n, r, p);
}

void pull_quant::reset() {
    m_imp->m_rw.reset();
}

/**
   \brief Forward to the rewriter configuration's pull_quant2: one pull step on
   the children (or body) of \c n followed by one on \c n itself.
*/
void pull_quant::pull_quant2(expr * n, expr_ref & r, proof_ref & pr) {
    imp::rw_cfg & c = m_imp->m_rw.cfg();
    c.pull_quant2(n, r, pr);
}

struct pull_nested_quant::imp {

    /**
       \brief Rewriter configuration that replaces every quantifier it reaches
       with the output of a pull_quant run on that quantifier.
    */
    struct rw_cfg : public default_rewriter_cfg {
        pull_quant m_pull; // puller applied to each visited quantifier
        expr_ref   m_r;    // owns the last term handed out through get_subst
        proof_ref  m_pr;   // owns the last proof handed out through get_subst

        rw_cfg(ast_manager & m):
            m_pull(m),
            m_r(m),
            m_pr(m) {
        }

        /**
           \brief Substitution hook: a quantifier \c s is replaced by pull_quant's
           result (with its proof in \c t_pr); for any other term, returns false.
        */
        bool get_subst(expr * s, expr * & t, proof * & t_pr) {
            if (is_quantifier(s)) {
                m_pull(s, m_r, m_pr);
                t_pr = m_pr.get();
                t    = m_r.get();
                return true;
            }
            return false;
        }
    };

    // Rewriter driven by rw_cfg. The base class receives a reference to m_cfg
    // before m_cfg is constructed; NOTE(review): this relies on rewriter_tpl only
    // storing that reference in its constructor.
    struct rw : public rewriter_tpl<rw_cfg> {
        rw_cfg m_cfg;
        rw(ast_manager & m):
            rewriter_tpl<rw_cfg>(m, m.proofs_enabled(), m_cfg),
            m_cfg(m) {
        }
    };

    rw m_rw;

    imp(ast_manager & m):
        m_rw(m) {
    }

    // Rewrite n; r receives the result and p the proof.
    void operator()(expr * n, expr_ref & r, proof_ref & p) {
        m_rw(n, r, p);
    }
};

/**
   \brief Create a nested-quantifier puller working over the terms of manager \c m.
*/
pull_nested_quant::pull_nested_quant(ast_manager & m):
    m_imp(alloc(imp, m)) {
}

/**
   \brief Release the implementation object allocated by the constructor.
*/
pull_nested_quant::~pull_nested_quant() { dealloc(m_imp); }

/**
   \brief Rewrite \c n, replacing each quantifier with its pulled form; the
   result goes to \c r and the accompanying proof to \c p.
*/
void pull_nested_quant::operator()(expr * n, expr_ref & r, proof_ref & p) {
    m_imp->operator()(n, r, p);
}

void pull_nested_quant::reset() {
    m_imp->m_rw.reset();
}


