diff options
Diffstat (limited to 'fs/userfaultfd.c')
-rw-r--r-- | fs/userfaultfd.c | 63 |
1 files changed, 39 insertions, 24 deletions
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 66cdb44616d5..85959d8324df 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -137,7 +137,7 @@ static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx) VM_BUG_ON(waitqueue_active(&ctx->fault_wqh)); VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock)); VM_BUG_ON(waitqueue_active(&ctx->fd_wqh)); - mmput(ctx->mm); + mmdrop(ctx->mm); kmem_cache_free(userfaultfd_ctx_cachep, ctx); } } @@ -257,10 +257,9 @@ out: * fatal_signal_pending()s, and the mmap_sem must be released before * returning it. */ -int handle_userfault(struct vm_area_struct *vma, unsigned long address, - unsigned int flags, unsigned long reason) +int handle_userfault(struct fault_env *fe, unsigned long reason) { - struct mm_struct *mm = vma->vm_mm; + struct mm_struct *mm = fe->vma->vm_mm; struct userfaultfd_ctx *ctx; struct userfaultfd_wait_queue uwq; int ret; @@ -269,7 +268,7 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address, BUG_ON(!rwsem_is_locked(&mm->mmap_sem)); ret = VM_FAULT_SIGBUS; - ctx = vma->vm_userfaultfd_ctx.ctx; + ctx = fe->vma->vm_userfaultfd_ctx.ctx; if (!ctx) goto out; @@ -302,17 +301,17 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address, * without first stopping userland access to the memory. For * VM_UFFD_MISSING userfaults this is enough for now. */ - if (unlikely(!(flags & FAULT_FLAG_ALLOW_RETRY))) { + if (unlikely(!(fe->flags & FAULT_FLAG_ALLOW_RETRY))) { /* * Validate the invariant that nowait must allow retry * to be sure not to return SIGBUS erroneously on * nowait invocations. */ - BUG_ON(flags & FAULT_FLAG_RETRY_NOWAIT); + BUG_ON(fe->flags & FAULT_FLAG_RETRY_NOWAIT); #ifdef CONFIG_DEBUG_VM if (printk_ratelimit()) { printk(KERN_WARNING - "FAULT_FLAG_ALLOW_RETRY missing %x\n", flags); + "FAULT_FLAG_ALLOW_RETRY missing %x\n", fe->flags); dump_stack(); } #endif @@ -324,7 +323,7 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address, * and wait. */ ret = VM_FAULT_RETRY; - if (flags & FAULT_FLAG_RETRY_NOWAIT) + if (fe->flags & FAULT_FLAG_RETRY_NOWAIT) goto out; /* take the reference before dropping the mmap_sem */ @@ -332,10 +331,11 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address, init_waitqueue_func_entry(&uwq.wq, userfaultfd_wake_function); uwq.wq.private = current; - uwq.msg = userfault_msg(address, flags, reason); + uwq.msg = userfault_msg(fe->address, fe->flags, reason); uwq.ctx = ctx; - return_to_userland = (flags & (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE)) == + return_to_userland = + (fe->flags & (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE)) == (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE); spin_lock(&ctx->fault_pending_wqh.lock); @@ -353,7 +353,7 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address, TASK_KILLABLE); spin_unlock(&ctx->fault_pending_wqh.lock); - must_wait = userfaultfd_must_wait(ctx, address, flags, reason); + must_wait = userfaultfd_must_wait(ctx, fe->address, fe->flags, reason); up_read(&mm->mmap_sem); if (likely(must_wait && !ACCESS_ONCE(ctx->released) && @@ -434,6 +434,9 @@ static int userfaultfd_release(struct inode *inode, struct file *file) ACCESS_ONCE(ctx->released) = true; + if (!mmget_not_zero(mm)) + goto wakeup; + /* * Flush page faults out of all CPUs. NOTE: all page faults * must be retried without returning VM_FAULT_SIGBUS if @@ -466,7 +469,8 @@ static int userfaultfd_release(struct inode *inode, struct file *file) vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; } up_write(&mm->mmap_sem); - + mmput(mm); +wakeup: /* * After no new page faults can wait on this fault_*wqh, flush * the last page faults that may have been already waiting on @@ -760,10 +764,12 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, start = uffdio_register.range.start; end = start + uffdio_register.range.len; + ret = -ENOMEM; + if (!mmget_not_zero(mm)) + goto out; + down_write(&mm->mmap_sem); vma = find_vma_prev(mm, start, &prev); - - ret = -ENOMEM; if (!vma) goto out_unlock; @@ -864,6 +870,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, } while (vma && vma->vm_start < end); out_unlock: up_write(&mm->mmap_sem); + mmput(mm); if (!ret) { /* * Now that we scanned all vmas we can already tell @@ -902,10 +909,12 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, start = uffdio_unregister.start; end = start + uffdio_unregister.len; + ret = -ENOMEM; + if (!mmget_not_zero(mm)) + goto out; + down_write(&mm->mmap_sem); vma = find_vma_prev(mm, start, &prev); - - ret = -ENOMEM; if (!vma) goto out_unlock; @@ -998,6 +1007,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, } while (vma && vma->vm_start < end); out_unlock: up_write(&mm->mmap_sem); + mmput(mm); out: return ret; } @@ -1067,9 +1077,11 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx, goto out; if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE) goto out; - - ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src, - uffdio_copy.len); + if (mmget_not_zero(ctx->mm)) { + ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src, + uffdio_copy.len); + mmput(ctx->mm); + } if (unlikely(put_user(ret, &user_uffdio_copy->copy))) return -EFAULT; if (ret < 0) @@ -1110,8 +1122,11 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx, if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE) goto out; - ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start, - uffdio_zeropage.range.len); + if (mmget_not_zero(ctx->mm)) { + ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start, + uffdio_zeropage.range.len); + mmput(ctx->mm); + } if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage))) return -EFAULT; if (ret < 0) @@ -1289,12 +1304,12 @@ static struct file *userfaultfd_file_create(int flags) ctx->released = false; ctx->mm = current->mm; /* prevent the mm struct to be freed */ - atomic_inc(&ctx->mm->mm_users); + atomic_inc(&ctx->mm->mm_count); file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx, O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS)); if (IS_ERR(file)) { - mmput(ctx->mm); + mmdrop(ctx->mm); kmem_cache_free(userfaultfd_ctx_cachep, ctx); } out: |