aboutsummaryrefslogtreecommitdiff
path: root/mm/mempolicy.c
diff options
context:
space:
mode:
Diffstat (limited to 'mm/mempolicy.c')
-rw-r--r--mm/mempolicy.c48
1 files changed, 28 insertions, 20 deletions
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index da858f794eb6..af171ccb56a2 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -350,7 +350,7 @@ static void mpol_rebind_policy(struct mempolicy *pol, const nodemask_t *newmask)
{
if (!pol)
return;
- if (!mpol_store_user_nodemask(pol) &&
+ if (!mpol_store_user_nodemask(pol) && !(pol->flags & MPOL_F_LOCAL) &&
nodes_equal(pol->w.cpuset_mems_allowed, *newmask))
return;
@@ -797,16 +797,19 @@ static void get_policy_nodemask(struct mempolicy *p, nodemask_t *nodes)
}
}
-static int lookup_node(unsigned long addr)
+static int lookup_node(struct mm_struct *mm, unsigned long addr)
{
struct page *p;
int err;
- err = get_user_pages(addr & PAGE_MASK, 1, 0, &p, NULL);
+ int locked = 1;
+ err = get_user_pages_locked(addr & PAGE_MASK, 1, 0, &p, &locked);
if (err >= 0) {
err = page_to_nid(p);
put_page(p);
}
+ if (locked)
+ up_read(&mm->mmap_sem);
return err;
}
@@ -817,7 +820,7 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
int err;
struct mm_struct *mm = current->mm;
struct vm_area_struct *vma = NULL;
- struct mempolicy *pol = current->mempolicy;
+ struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
if (flags &
~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
@@ -857,7 +860,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
if (flags & MPOL_F_NODE) {
if (flags & MPOL_F_ADDR) {
- err = lookup_node(addr);
+ /*
+ * Take a refcount on the mpol, lookup_node()
+ * wil drop the mmap_sem, so after calling
+ * lookup_node() only "pol" remains valid, "vma"
+ * is stale.
+ */
+ pol_refcount = pol;
+ vma = NULL;
+ mpol_get(pol);
+ err = lookup_node(mm, addr);
if (err < 0)
goto out;
*policy = err;
@@ -892,7 +904,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
out:
mpol_cond_put(pol);
if (vma)
- up_read(&current->mm->mmap_sem);
+ up_read(&mm->mmap_sem);
+ if (pol_refcount)
+ mpol_put(pol_refcount);
return err;
}
@@ -1300,7 +1314,7 @@ static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
nodemask_t *nodes)
{
unsigned long copy = ALIGN(maxnode-1, 64) / 8;
- const int nbytes = BITS_TO_LONGS(MAX_NUMNODES) * sizeof(long);
+ unsigned int nbytes = BITS_TO_LONGS(nr_node_ids) * sizeof(long);
if (copy > nbytes) {
if (copy > PAGE_SIZE)
@@ -1477,7 +1491,7 @@ static int kernel_get_mempolicy(int __user *policy,
int uninitialized_var(pval);
nodemask_t nodes;
- if (nmask != NULL && maxnode < MAX_NUMNODES)
+ if (nmask != NULL && maxnode < nr_node_ids)
return -EINVAL;
err = do_get_mempolicy(&pval, &nodes, addr, flags);
@@ -1513,7 +1527,7 @@ COMPAT_SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
unsigned long nr_bits, alloc_size;
DECLARE_BITMAP(bm, MAX_NUMNODES);
- nr_bits = min_t(unsigned long, maxnode-1, MAX_NUMNODES);
+ nr_bits = min_t(unsigned long, maxnode-1, nr_node_ids);
alloc_size = ALIGN(nr_bits, BITS_PER_LONG) / 8;
if (nmask)
@@ -2039,8 +2053,7 @@ alloc_pages_vma(gfp_t gfp, int order, struct vm_area_struct *vma,
* If the policy is interleave, or does not allow the current
* node in its nodemask, we allocate the standard way.
*/
- if (pol->mode == MPOL_PREFERRED &&
- !(pol->flags & MPOL_F_LOCAL))
+ if (pol->mode == MPOL_PREFERRED && !(pol->flags & MPOL_F_LOCAL))
hpage_node = pol->v.preferred_node;
nmask = policy_nodemask(gfp, pol);
@@ -2291,7 +2304,7 @@ int mpol_misplaced(struct page *page, struct vm_area_struct *vma, unsigned long
unsigned long pgoff;
int thiscpu = raw_smp_processor_id();
int thisnid = cpu_to_node(thiscpu);
- int polnid = -1;
+ int polnid = NUMA_NO_NODE;
int ret = -1;
pol = get_vma_policy(vma, addr);
@@ -2697,12 +2710,11 @@ static const char * const policy_modes[] =
int mpol_parse_str(char *str, struct mempolicy **mpol)
{
struct mempolicy *new = NULL;
- unsigned short mode;
unsigned short mode_flags;
nodemask_t nodes;
char *nodelist = strchr(str, ':');
char *flags = strchr(str, '=');
- int err = 1;
+ int err = 1, mode;
if (nodelist) {
/* NUL-terminate mode or flags string */
@@ -2717,12 +2729,8 @@ int mpol_parse_str(char *str, struct mempolicy **mpol)
if (flags)
*flags++ = '\0'; /* terminate mode string */
- for (mode = 0; mode < MPOL_MAX; mode++) {
- if (!strcmp(str, policy_modes[mode])) {
- break;
- }
- }
- if (mode >= MPOL_MAX)
+ mode = match_string(policy_modes, MPOL_MAX, str);
+ if (mode < 0)
goto out;
switch (mode) {