package rewrite import ( "github.com/miekg/dns" ) // RevertPolicy controls the overall reverting process type RevertPolicy interface { DoRevert() bool DoQuestionRestore() bool } type revertPolicy struct { noRevert bool noRestore bool } func (p revertPolicy) DoRevert() bool { return !p.noRevert } func (p revertPolicy) DoQuestionRestore() bool { return !p.noRestore } // NoRevertPolicy disables all response rewrite rules func NoRevertPolicy() RevertPolicy { return revertPolicy{true, false} } // NoRestorePolicy disables the question restoration during the response rewrite func NoRestorePolicy() RevertPolicy { return revertPolicy{false, true} } // NewRevertPolicy creates a new reverter policy by dynamically specifying all // options. func NewRevertPolicy(noRevert, noRestore bool) RevertPolicy { return revertPolicy{noRestore: noRestore, noRevert: noRevert} } // ResponseRule contains a rule to rewrite a response with. type ResponseRule interface { RewriteResponse(res *dns.Msg, rr dns.RR) } // ResponseRules describes an ordered list of response rules to apply // after a name rewrite type ResponseRules = []ResponseRule // ResponseReverter reverses the operations done on the question section of a packet. // This is need because the client will otherwise disregards the response, i.e. // dig will complain with ';; Question section mismatch: got example.org/HINFO/IN' type ResponseReverter struct { dns.ResponseWriter originalQuestion dns.Question ResponseRules ResponseRules revertPolicy RevertPolicy } // NewResponseReverter returns a pointer to a new ResponseReverter. func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg, policy RevertPolicy) *ResponseReverter { return &ResponseReverter{ ResponseWriter: w, originalQuestion: r.Question[0], revertPolicy: policy, } } // WriteMsg records the status code and calls the underlying ResponseWriter's WriteMsg method. func (r *ResponseReverter) WriteMsg(res1 *dns.Msg) error { // Deep copy 'res' as to not (e.g). rewrite a message that's also stored in the cache. res := res1.Copy() if r.revertPolicy.DoQuestionRestore() { res.Question[0] = r.originalQuestion } if len(r.ResponseRules) > 0 { for _, rr := range res.Ns { r.rewriteResourceRecord(res, rr) } for _, rr := range res.Answer { r.rewriteResourceRecord(res, rr) } for _, rr := range res.Extra { r.rewriteResourceRecord(res, rr) } } return r.ResponseWriter.WriteMsg(res) } func (r *ResponseReverter) rewriteResourceRecord(res *dns.Msg, rr dns.RR) { // The reverting rules need to be done in reversed order. for i := len(r.ResponseRules) - 1; i >= 0; i-- { r.ResponseRules[i].RewriteResponse(res, rr) } } // Write is a wrapper that records the size of the message that gets written. func (r *ResponseReverter) Write(buf []byte) (int, error) { n, err := r.ResponseWriter.Write(buf) return n, err } func getRecordValueForRewrite(rr dns.RR) (name string) { switch rr.Header().Rrtype { case dns.TypeSRV: return rr.(*dns.SRV).Target case dns.TypeMX: return rr.(*dns.MX).Mx case dns.TypeCNAME: return rr.(*dns.CNAME).Target case dns.TypeNS: return rr.(*dns.NS).Ns case dns.TypeDNAME: return rr.(*dns.DNAME).Target case dns.TypeNAPTR: return rr.(*dns.NAPTR).Replacement case dns.TypeSOA: return rr.(*dns.SOA).Ns case dns.TypePTR: return rr.(*dns.PTR).Ptr default: return "" } } func setRewrittenRecordValue(rr dns.RR, value string) { switch rr.Header().Rrtype { case dns.TypeSRV: rr.(*dns.SRV).Target = value case dns.TypeMX: rr.(*dns.MX).Mx = value case dns.TypeCNAME: rr.(*dns.CNAME).Target = value case dns.TypeNS: rr.(*dns.NS).Ns = value case dns.TypeDNAME: rr.(*dns.DNAME).Target = value case dns.TypeNAPTR: rr.(*dns.NAPTR).Replacement = value case dns.TypeSOA: rr.(*dns.SOA).Ns = value case dns.TypePTR: rr.(*dns.PTR).Ptr = value } }