buildbrain/tools/aptcache.go
2020-10-27 19:52:59 +09:00

257 lines
5.6 KiB
Go

// See https://github.com/puhitaku/empera for the origin
package main
import (
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
)
// Proxy is a caching HTTP proxy in front of a single remote host.
// Cached responses live as flat files in the "cache" directory, named by the
// path-escaped request path.
type Proxy struct {
	remote    string              // host that requests are forwarded to; assigned by Run
	cli       *http.Client        // client used for upstream requests
	cache     map[string]struct{} // set of escaped request paths that have a cache file
	cacheLock sync.Mutex          // guards cache
}

// NewProxy returns a Proxy backed by the "cache" directory in the current
// working directory. The directory is created when it does not exist, and
// every entry already inside it is registered as cached.
//
// An error is returned when "cache" exists but is not a directory, or when
// the directory cannot be inspected, created, or listed.
func NewProxy() (*Proxy, error) {
	p := &Proxy{
		cli:   http.DefaultClient,
		cache: map[string]struct{}{},
	}

	stat, err := os.Stat("cache")
	switch {
	case err == nil:
		if !stat.IsDir() {
			return nil, fmt.Errorf("non-directory 'cache' exists")
		}
	case os.IsNotExist(err):
		// os.IsNotExist avoids asserting on the concrete error type, which
		// would panic if the error were ever not an *os.PathError.
		if err := os.Mkdir("cache", 0755); err != nil {
			return nil, fmt.Errorf("failed to create cache directory: %s", err)
		}
	default:
		return nil, fmt.Errorf("failed to stat cache directory: %s", err)
	}

	matches, err := filepath.Glob(filepath.Join("cache", "*"))
	if err != nil {
		return nil, fmt.Errorf("failed to glob cache directory: %s", err)
	}
	for _, m := range matches {
		// Base drops the directory part using the platform's separator, so
		// the key is exactly the file name regardless of how Glob spells
		// the path.
		p.cache[filepath.Base(m)] = struct{}{}
	}
	return p, nil
}
// Run records remote as the upstream host and serves HTTP on the local
// address with p as the handler. It only returns by panicking: any error
// reported by http.ListenAndServe is raised as a panic.
func (p *Proxy) Run(local, remote string) {
	p.remote = remote
	if serveErr := http.ListenAndServe(local, p); serveErr != nil {
		panic(serveErr)
	}
}
// ServeHTTP implements the http.Handler interface.
//
// Each request is handled in one of three ways, chosen from the path-escaped
// request path:
//   - paths containing "Release", "Packages" or "Contents" are always fetched
//     from the remote and never stored (presumably repository index files
//     that change over time; confirm against the repositories in use);
//   - paths already present in the cache set are served from disk;
//   - everything else is fetched from the remote and stored.
//
// One log line per request is printed to standard output.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	encoded := url.PathEscape(r.URL.Path)

	nocache := false
	for _, ex := range []string{"Release", "Packages", "Contents"} {
		if strings.Contains(encoded, ex) {
			nocache = true
			break
		}
	}

	// The cache set is written by fetchFromRemote under cacheLock from other
	// request goroutines, so the lookup must hold the same lock; an unguarded
	// concurrent map read and write aborts the program.
	cached := false
	if !nocache {
		p.cacheLock.Lock()
		_, cached = p.cache[encoded]
		p.cacheLock.Unlock()
	}

	var err error
	switch {
	case nocache:
		fmt.Printf("GET (no cache): %s%s -> ", p.remote, r.URL.Path)
		err = p.fetchFromRemote(w, r, false)
	case cached:
		fmt.Printf("GET (cache hit): %s%s -> ", p.remote, r.URL.Path)
		err = p.fetchFromCache(w, r)
	default:
		fmt.Printf("GET (cache miss): %s%s -> ", p.remote, r.URL.Path)
		err = p.fetchFromRemote(w, r, true)
	}

	if err != nil {
		fmt.Printf("%s\n", err)
		return
	}
	fmt.Printf("200\n")
}
// fetchFromRemote forwards the request to the remote host over plain HTTP and
// relays the response (headers, status and body) to the client.
//
// When cache is true and the remote answers 200, the body is also written to
// the cache file for this path, and the path is added to the cache set once
// the whole body has been copied and the file closed. If the copy or the
// close fails, the partial file is removed so that it cannot be mistaken for
// a complete entry when the cache directory is scanned on the next start.
//
// The returned error is meant for logging: it is nil for a fully relayed 200
// response, carries the status code as text for any other status, and
// otherwise describes the step that failed. Failures that happen before the
// response has started are also reported to the client as a 500.
//
// NOTE(review): concurrent downloads of the same path still write to the same
// cache file; this function does not serialize them.
func (p *Proxy) fetchFromRemote(w http.ResponseWriter, r *http.Request, cache bool) error {
	encoded := url.PathEscape(r.URL.Path)
	fpath := filepath.Join("cache", encoded)

	// fail sends err to the client as a 500 and returns it prefixed with the
	// step that failed.
	fail := func(what string, err error) error {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return fmt.Errorf("%s: %s", what, err)
	}

	newURL, err := url.Parse(r.URL.String())
	if err != nil {
		return fail("failed to parse URL", err)
	}
	newURL.Scheme = "http"
	newURL.Host = p.remote

	req, err := http.NewRequest(http.MethodGet, newURL.String(), nil)
	if err != nil {
		return fail("failed to create a new request", err)
	}
	req.Header = r.Header

	res, err := p.cli.Do(req)
	if err != nil {
		return fail("failed to GET", err)
	}
	defer res.Body.Close()

	// Give up early if the cache path cannot be inspected for any reason
	// other than the file not existing yet. The comma-ok assertion keeps an
	// unexpected error type from panicking.
	if _, err := os.Stat(fpath); err != nil {
		if perr, ok := err.(*os.PathError); !ok || perr.Err != syscall.ENOENT {
			return fail("failed to stat cached file", err)
		}
	}

	// Only successful responses are stored, and only when the caller asked.
	store := cache && res.StatusCode == http.StatusOK

	var f io.WriteCloser = NullWriter{}
	if store {
		f, err = os.Create(fpath)
		if err != nil {
			return fail("failed to create file", err)
		}
	}

	for k, vs := range res.Header {
		for _, v := range vs {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(res.StatusCode)

	// Stream to the client while teeing the same bytes into the cache file.
	_, copyErr := io.Copy(w, io.TeeReader(res.Body, f))
	// Close before publishing the entry so a close error is not lost and the
	// file is complete by the time it is advertised as cached.
	closeErr := f.Close()
	if copyErr != nil || closeErr != nil {
		if store {
			// Best effort: a leftover partial file would be served as a
			// valid entry after a restart.
			_ = os.Remove(fpath)
		}
		if copyErr != nil {
			return fmt.Errorf("failed to copy: %s", copyErr)
		}
		return fmt.Errorf("failed to close cached file: %s", closeErr)
	}

	if res.StatusCode != http.StatusOK {
		// Constant format string: the status text is data, not a format.
		return fmt.Errorf("%s", strconv.Itoa(res.StatusCode))
	}

	if store {
		p.cacheLock.Lock()
		p.cache[encoded] = struct{}{}
		p.cacheLock.Unlock()
	}
	return nil
}
// fetchFromCache streams the cache file for the requested path to the client.
//
// If the file cannot be opened, the client receives a 500 carrying the error
// text and an error naming the escaped path is returned. A failure while
// copying is returned as well; nil means the whole file was written.
func (p *Proxy) fetchFromCache(w http.ResponseWriter, r *http.Request) error {
	key := url.PathEscape(r.URL.Path)

	cached, openErr := os.Open(filepath.Join("cache", key))
	if openErr != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(openErr.Error()))
		return fmt.Errorf("failed to open '%s': %s", key, openErr)
	}
	defer cached.Close()

	if _, copyErr := io.Copy(w, cached); copyErr != nil {
		return fmt.Errorf("failed to copy: %s", copyErr)
	}
	return nil
}
// NullWriter is a write-closer that discards everything written to it.
type NullWriter struct{}

// Write reports every byte of b as written without storing anything.
func (NullWriter) Write(b []byte) (n int, err error) {
	n = len(b)
	return n, nil
}

// Close does nothing and always succeeds.
func (NullWriter) Close() error { return nil }
// rule pairs a local listen address with the remote host it proxies to.
type rule struct {
	Local, Remote string
}

// rules is the list of proxy rules collected from repeated -rule flags.
// Together with String and Set it satisfies flag.Value.
type rules []rule

// String always returns the empty string; it exists so that *rules can be
// used as a flag.Value.
func (r *rules) String() string {
	return ""
}

// Set parses one rule and appends it to r.
//
// raw is a comma-separated list of key=value pairs. The only accepted keys
// are "local" and "remote", both are required, and whitespace around keys and
// values is ignored. Example: "local=localhost:8080, remote=repo.example.com".
// When a key is repeated, the last value wins.
//
// An error is returned, and r is left unchanged, when a pair does not have
// exactly one "=", when a key is unknown, or when either key is missing or
// has an empty value.
func (r *rules) Set(raw string) error {
	var local, remote string

	for _, kv := range strings.Split(raw, ",") {
		tokens := strings.Split(kv, "=")
		if len(tokens) != 2 {
			return fmt.Errorf("rule is malformed")
		}

		key, value := strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1])
		switch key {
		case "local":
			local = value
		case "remote":
			remote = value
		default:
			return fmt.Errorf("rule has unknown key: '%s'", key)
		}
	}

	if local == "" || remote == "" {
		return fmt.Errorf("rule lacks mandatory keys: 'local' and/or 'remote'")
	}

	*r = append(*r, rule{Local: local, Remote: remote})
	return nil
}
// main parses the repeatable -rule flag, starts one proxy per rule in its own
// goroutine, and then keeps the main goroutine alive by sleeping in a loop.
//
// With no rules given it prints a message and the flag usage to standard
// error and exits with status 1. A proxy that cannot be created causes a
// panic.
func main() {
	var proxyRules rules
	flag.Var(&proxyRules, "rule", "Proxy rule. example: -rule 'local=localhost:8080, remote=super.slow.repository.example.com'")
	flag.Parse()

	if len(proxyRules) == 0 {
		fmt.Fprint(os.Stderr, "Fatal: specify one or more rules.\n")
		flag.Usage()
		os.Exit(1)
	}

	for idx := range proxyRules {
		local, remote := proxyRules[idx].Local, proxyRules[idx].Remote
		fmt.Printf("Proxy Rule %d: %s -> %s\n", idx+1, local, remote)

		proxy, err := NewProxy()
		if err != nil {
			panic(err)
		}
		go proxy.Run(local, remote)
	}

	// The proxies run in background goroutines; returning from main would
	// end the process, so sleep repeatedly instead.
	for {
		time.Sleep(9999999999)
	}
}